4 Commits

Author SHA1 Message Date
zxq5 7e001d8769 idk 2026-06-03 19:12:23 +01:00
zxq5 2f34976f3e Merge remote-tracking branch 'origin/dev' into dev 2026-04-11 00:10:09 +01:00
zxq5 d1208f7e39 frontend v0.4.1-2
- added invite section to UI and some general bug fixes
2026-04-11 00:09:47 +01:00
zxq5 d6ba875297 addedd RELEASE_MODE=1 to run var to prevent crash in absence of .env
file
2026-04-08 00:05:54 +01:00
42 changed files with 1026 additions and 280 deletions
+1 -5
View File
@@ -1,12 +1,9 @@
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android" <manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"> xmlns:tools="http://tools.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" /> <uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS"/> <uses-permission android:name="android.permission.POST_NOTIFICATIONS"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
<application <application
android:name=".ChatApplication" android:name=".ChatApplication"
@@ -22,7 +19,6 @@
<service <service
android:name=".core.service.MessageStreamService" android:name=".core.service.MessageStreamService"
android:foregroundServiceType="dataSync"
android:exported="false"/> android:exported="false"/>
<activity <activity
@@ -1,9 +1,13 @@
package dev.zxq5.chatapp.android package dev.zxq5.chatapp.android
import android.Manifest
import android.os.Build
import android.os.Bundle import android.os.Bundle
import androidx.activity.ComponentActivity import androidx.activity.ComponentActivity
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.compose.setContent import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
@@ -62,14 +66,37 @@ class MainActivity : ComponentActivity() {
val currentScreen by chatViewModel.currentScreen.collectAsState() val currentScreen by chatViewModel.currentScreen.collectAsState()
val selectedChannelId by chatViewModel.channelId.collectAsState() val selectedChannelId by chatViewModel.channelId.collectAsState()
// Permission request launcher
val launcher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.RequestPermission(),
onResult = { isGranted ->
if (isGranted && authState == AuthState.Authenticated) {
MessageStreamService.start(this@MainActivity)
}
}
)
LaunchedEffect(authState) { LaunchedEffect(authState) {
when (authState) { when (authState) {
AuthState.Authenticated -> MessageStreamService.start(this@MainActivity) AuthState.Authenticated -> {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
launcher.launch(Manifest.permission.POST_NOTIFICATIONS)
}
MessageStreamService.start(this@MainActivity)
chatViewModel.loadAccessibleChannels()
}
AuthState.Unauthenticated -> MessageStreamService.stop(this@MainActivity) AuthState.Unauthenticated -> MessageStreamService.stop(this@MainActivity)
AuthState.AwaitingTotp -> {} AuthState.AwaitingTotp -> {}
} }
} }
LaunchedEffect(Unit) {
chatViewModel.onUnauthorized = {
authViewModel.logout()
chatViewModel.clearChat()
}
}
LaunchedEffect(Unit) { LaunchedEffect(Unit) {
intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let { intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let {
chatViewModel.switchChannel(it.toLong()) chatViewModel.switchChannel(it.toLong())
@@ -1,7 +1,9 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.api package dev.zxq5.chatapp.android.api
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.api.model.SendMessage import dev.zxq5.chatapp.android.api.model.SendMessage
import dev.zxq5.chatapp.android.api.model.SpaceDto import dev.zxq5.chatapp.android.api.model.SpaceDto
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
@@ -25,6 +27,8 @@ import kotlinx.coroutines.flow.flow
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import kotlin.time.Clock import kotlin.time.Clock
import kotlin.time.ExperimentalTime import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
class ChatClient(private val token: String) { class ChatClient(private val token: String) {
@@ -45,18 +49,18 @@ class ChatClient(private val token: String) {
suspend fun sendMessage(channelId: Long, userId: Int, text: String) { suspend fun sendMessage(channelId: Long, userId: Int, text: String) {
http.post("${BASE_URL}/api/chat/$channelId") { http.post("${BASE_URL}/api/chat/$channelId") {
contentType(ContentType.Application.Json) contentType(ContentType.Application.Json)
setBody(SendMessage(user_id = userId, text = text, timestamp = Clock.System.now())) setBody(SendMessage(id = Uuid.random(), user_id = userId, text = text, timestamp = Clock.System.now()))
} }
} }
fun messageStream(channelId: Long): Flow<Message> = flow { fun eventStream(channelId: Long): Flow<ChatEvent> = flow {
http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response -> http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response ->
val channel = response.bodyAsChannel() val channel = response.bodyAsChannel()
while (!channel.isClosedForRead) { while (!channel.isClosedForRead) {
val line = channel.readLine() ?: break val line = channel.readLine() ?: break
if (line.startsWith("data:")) { if (line.startsWith("data:")) {
val json = line.removePrefix("data:").trim() val json = line.removePrefix("data:").trim()
runCatching { Json.decodeFromString<Message>(json) } runCatching { Json.decodeFromString<ChatEvent>(json) }
.onSuccess { emit(it) } .onSuccess { emit(it) }
} }
} }
@@ -4,6 +4,7 @@ import android.util.Log
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest
import dev.zxq5.chatapp.android.api.model.DisplayNameRequest import dev.zxq5.chatapp.android.api.model.DisplayNameRequest
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest
import dev.zxq5.chatapp.android.api.model.QrResponse import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode
@@ -45,6 +46,22 @@ class SettingsClient(private val token: String) {
} }
} }
suspend fun createInvite(request: InviteRequest): ApiResult<String> {
return try {
val response = http.post("${BASE_URL}/api/invite") {
contentType(ContentType.Application.Json)
setBody(request)
}
if (response.status.isSuccess()) {
ApiResult.Success(response.body<String>())
} else {
ApiResult.HttpError(response.status.value, "Failed to create invite")
}
} catch (e: Exception) {
ApiResult.NetworkError(e.localizedMessage ?: "Network error")
}
}
suspend fun getTotpQr(password: String): ApiResult<QrResponse> { suspend fun getTotpQr(password: String): ApiResult<QrResponse> {
return try { return try {
val response = http.post("${BASE_URL}/api/totp.jpg") { val response = http.post("${BASE_URL}/api/totp.jpg") {
@@ -0,0 +1,60 @@
@file:OptIn(ExperimentalUuidApi::class, ExperimentalTime::class)
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialName
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
import kotlinx.serialization.Serializable
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
@OptIn(ExperimentalSerializationApi::class)
@Serializable
@JsonClassDiscriminator("type")
sealed class ChatEvent {
@Serializable
@SerialName("SendMessage")
data class SendMessage(
val data: Message
) : ChatEvent()
@Serializable
@SerialName("EditMessage")
data class EditMessage(
val data: EditMessageContent
) : ChatEvent()
@Serializable
@SerialName("MessageAppendContent")
data class MessageAppendContent(
val data: AppendContent
) : ChatEvent()
}
// tuple variants like (i64, ChatMsg) and (i64, String)
// need wrapper classes since kotlinx can't deserialise
// bare JSON arrays into data classes directly
@Serializable
data class EditMessageContent(
val id: Uuid,
val message: Message
)
@Serializable
data class AppendContent (
val id: Uuid,
val content: String
)
@Serializable
data class Message (
val id: Uuid,
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Instant
)
@@ -0,0 +1,13 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class InviteRequest @OptIn(ExperimentalTime::class) constructor(
val name: String,
val max_uses: Int,
val expiry_date: Instant,
val start_date: Instant
)
@@ -1,4 +1,4 @@
package dev.zxq5.chatapp.android.model package dev.zxq5.chatapp.android.api.model
sealed class LoginState { sealed class LoginState {
object Idle : LoginState() object Idle : LoginState()
@@ -1,13 +0,0 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Message @OptIn(ExperimentalTime::class) constructor(
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Instant
)
@@ -3,9 +3,12 @@ package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime import kotlin.time.ExperimentalTime
import kotlin.time.Instant import kotlin.time.Instant
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
@Serializable @Serializable
data class SendMessage @OptIn(ExperimentalTime::class) constructor( data class SendMessage @OptIn(ExperimentalTime::class, ExperimentalUuidApi::class) constructor(
val id: Uuid,
val user_id: Int, val user_id: Int,
val text: String, val text: String,
val timestamp: Instant val timestamp: Instant
@@ -3,10 +3,12 @@ package dev.zxq5.chatapp.android.core.data
import android.content.Context import android.content.Context
import android.content.SharedPreferences import android.content.SharedPreferences
import android.util.Base64 import android.util.Base64
import android.util.Log
import androidx.core.content.edit import androidx.core.content.edit
import androidx.security.crypto.EncryptedSharedPreferences import androidx.security.crypto.EncryptedSharedPreferences
import androidx.security.crypto.MasterKey import androidx.security.crypto.MasterKey
import org.json.JSONObject import org.json.JSONObject
import java.time.Instant
private const val KEY = "auth_token" private const val KEY = "auth_token"
private const val TWOFA_KEY = "twofa_enabled" private const val TWOFA_KEY = "twofa_enabled"
@@ -27,11 +29,37 @@ class TokenStore(appContext: Context) {
) )
} }
fun save(token: String) = fun save(token: String) {
prefs().edit { putString(KEY, token) } Log.d("TokenStore", "Saving token: $token")
prefs().edit { putString(KEY, token) }
}
fun get(): String? {
val ret = prefs().getString(KEY, null)
Log.d("TokenStore", "Retrieved token: $ret")
return ret
}
fun isExpired(): Boolean {
val token = get() ?: return true
return try {
val payload = token.split(".")[1]
val padded = payload + "==".take((4 - payload.length % 4) % 4)
val jsonString = String(Base64.decode(padded, Base64.URL_SAFE))
val json = JSONObject(jsonString)
if (json.has("exp")) {
val exp = json.getLong("exp")
val now = Instant.now().epochSecond
now >= exp
} else {
false // If no exp claim, assume not expired or handle differently
}
} catch (e: Exception) {
true // If we can't parse it, treat it as expired
}
}
fun get(): String? =
prefs().getString(KEY, null)
fun save2faEnabled( enabled: Boolean) = fun save2faEnabled( enabled: Boolean) =
prefs().edit { putBoolean(TWOFA_KEY, enabled) } prefs().edit { putBoolean(TWOFA_KEY, enabled) }
@@ -3,10 +3,10 @@ package dev.zxq5.chatapp.android.core.service
import android.app.Service import android.app.Service
import android.content.Context import android.content.Context
import android.content.Intent import android.content.Intent
import android.os.Build
import android.os.IBinder import android.os.IBinder
import android.util.Log import android.util.Log
import dev.zxq5.chatapp.android.ChatApplication import dev.zxq5.chatapp.android.ChatApplication
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.data.repository.ChatRepository
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@@ -15,21 +15,17 @@ import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
// core/service/MessageStreamService.kt
class MessageStreamService : Service() { class MessageStreamService : Service() {
private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private lateinit var notificationService: NotificationService private lateinit var notificationService: NotificationService
private lateinit var chatRepository: ChatRepository private lateinit var chatRepository: ChatRepository
// which channel the user is currently looking at
// set by the ViewModel when the user opens/closes a channel
var activeChannelId: Long? = null var activeChannelId: Long? = null
set(value) { set(value) {
field = value field = value
Log.d("Service", "activeChannelId set to $value") Log.d("Service", "activeChannelId set to $value")
if (value != null) { if (value != null) {
// restart stream with new channel
currentStreamJob?.cancel() currentStreamJob?.cancel()
observeMessages() observeMessages()
} }
@@ -42,12 +38,8 @@ class MessageStreamService : Service() {
fun start(context: Context) { fun start(context: Context) {
val intent = Intent(context, MessageStreamService::class.java) val intent = Intent(context, MessageStreamService::class.java)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
context.startForegroundService(intent)
} else {
context.startService(intent) context.startService(intent)
} }
}
fun stop(context: Context) { fun stop(context: Context) {
context.stopService(Intent(context, MessageStreamService::class.java)) context.stopService(Intent(context, MessageStreamService::class.java))
@@ -62,33 +54,29 @@ class MessageStreamService : Service() {
} }
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
startForeground(
NotificationService.FOREGROUND_NOTIFICATION_ID,
notificationService.buildForegroundNotification()
)
observeMessages() observeMessages()
return START_STICKY // restart if killed return START_STICKY
} }
private fun observeMessages() { private fun observeMessages() {
val channelId = activeChannelId ?: chatRepository.getLastActiveChannel() val channelId = activeChannelId ?: chatRepository.getLastActiveChannel()
Log.d("Service", "observeMessages called, channelId=$channelId") if (channelId == null) return
if (channelId == null) {
Log.d("Service", "No channel to observe, waiting for switchChannel")
return
}
Log.d("Service", "Starting stream for channel $channelId")
currentStreamJob = serviceScope.launch { currentStreamJob = serviceScope.launch {
chatRepository.messageStream(channelId) chatRepository.eventStream(channelId)
.catch { e -> Log.e("Service", "Stream error", e) } .catch { e -> Log.e("Service", "Stream error", e) }
.collect { message -> .collect { event ->
if (!ChatApplication.AppState.isInForeground) { // no channel focused, always notify // Only show notification when an event (new message) is received
notificationService.showMessageNotification( // and the app is not in the foreground on this channel.
conversationId = activeChannelId.toString(), if (!ChatApplication.AppState.isInForeground || activeChannelId != channelId) {
senderName = message.display_name, when (event) {
messagePreview = message.text.take(80) is ChatEvent.SendMessage -> notificationService.showMessageNotification(
conversationId = channelId.toString(),
senderName = event.data.display_name,
messagePreview = event.data.text
) )
else -> {}
}
} }
} }
} }
@@ -15,51 +15,41 @@ class NotificationService(private val context: Context) {
companion object { companion object {
const val CHANNEL_ID = "messages" const val CHANNEL_ID = "messages"
const val FOREGROUND_NOTIFICATION_ID = 1 // ← this needs to exist const val SERVICE_CHANNEL_ID = "service"
const val FOREGROUND_NOTIFICATION_ID = 1
} }
private val manager = context.getSystemService(NotificationManager::class.java) private val manager = context.getSystemService(NotificationManager::class.java)
fun createChannels() { fun createForegroundNotification(): Notification {
// channel for new message notifications val intent = Intent(context, MainActivity::class.java).apply {
val messageChannel = NotificationChannel( flags = Intent.FLAG_ACTIVITY_NEW_TASK or Intent.FLAG_ACTIVITY_CLEAR_TASK
CHANNEL_ID,
"Messages",
NotificationManager.IMPORTANCE_HIGH
).apply {
enableVibration(true)
} }
// channel for the persistent foreground service notification val pendingIntent = PendingIntent.getActivity(
// low importance so it doesn't make noise context,
val serviceChannel = NotificationChannel( 0,
"service", intent,
"Background connection", PendingIntent.FLAG_IMMUTABLE
NotificationManager.IMPORTANCE_LOW
) )
val mgr = context.getSystemService(NotificationManager::class.java) return NotificationCompat.Builder(context, SERVICE_CHANNEL_ID)
mgr.createNotificationChannel(messageChannel)
mgr.createNotificationChannel(serviceChannel)
}
fun buildForegroundNotification(): Notification {
return NotificationCompat.Builder(context, "service")
.setSmallIcon(R.drawable.ic_notification) .setSmallIcon(R.drawable.ic_notification)
.setContentTitle("chatapp") .setContentTitle("Chat App")
.setContentText("Connected") .setContentText("Connecting to message stream...")
.setPriority(NotificationCompat.PRIORITY_LOW)
.setCategory(NotificationCompat.CATEGORY_SERVICE)
.setContentIntent(pendingIntent)
.setOngoing(true) .setOngoing(true)
.setSilent(true)
.build() .build()
} }
fun showMessageNotification( fun showMessageNotification(
conversationId: String, conversationId: String,
senderName: String, senderName: String,
messagePreview: String, // for E2E this would be "New message" — no plaintext messagePreview: String,
notificationId: Int = conversationId.hashCode() notificationId: Int = conversationId.hashCode()
) { ) {
// intent that opens the app to the right conversation when tapped
val intent = Intent(context, MainActivity::class.java).apply { val intent = Intent(context, MainActivity::class.java).apply {
flags = Intent.FLAG_ACTIVITY_SINGLE_TOP flags = Intent.FLAG_ACTIVITY_SINGLE_TOP
putExtra("conversation_id", conversationId) putExtra("conversation_id", conversationId)
@@ -72,13 +62,13 @@ class NotificationService(private val context: Context) {
PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE
) )
val notification = NotificationCompat.Builder(context, "messages") val notification = NotificationCompat.Builder(context, CHANNEL_ID)
.setSmallIcon(R.drawable.ic_notification) .setSmallIcon(R.drawable.ic_notification)
.setContentTitle(senderName) .setContentTitle(senderName)
.setContentText(messagePreview) .setContentText(messagePreview)
.setPriority(NotificationCompat.PRIORITY_HIGH) .setPriority(NotificationCompat.PRIORITY_HIGH)
.setContentIntent(pendingIntent) .setContentIntent(pendingIntent)
.setAutoCancel(true) // dismiss on tap .setAutoCancel(true)
.build() .build()
manager.notify(notificationId, notification) manager.notify(notificationId, notification)
@@ -56,6 +56,10 @@ class AuthRepository(
fun getAuthState(): AuthState { fun getAuthState(): AuthState {
val token = tokenStore.get() ?: return AuthState.Unauthenticated val token = tokenStore.get() ?: return AuthState.Unauthenticated
if (tokenStore.isExpired()) {
tokenStore.clear()
return AuthState.Unauthenticated
}
return when (getScopeFromToken(token)) { return when (getScopeFromToken(token)) {
TokenScope.FULL -> AuthState.Authenticated TokenScope.FULL -> AuthState.Authenticated
TokenScope.TOTP_PENDING -> AuthState.AwaitingTotp TokenScope.TOTP_PENDING -> AuthState.AwaitingTotp
@@ -1,6 +1,7 @@
package dev.zxq5.chatapp.android.data.repository package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.ChatClient import dev.zxq5.chatapp.android.api.ChatClient
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SpaceDto import dev.zxq5.chatapp.android.api.model.SpaceDto
@@ -43,8 +44,8 @@ class ChatRepository(private val tokenStore: TokenStore) {
getChatClient()?.sendMessage(channelId, userId, text) getChatClient()?.sendMessage(channelId, userId, text)
} }
fun messageStream(channelId: Long): Flow<Message> { fun eventStream(channelId: Long): Flow<ChatEvent> {
_lastActiveChannel = channelId _lastActiveChannel = channelId
return getChatClient()?.messageStream(channelId) ?: emptyFlow() return getChatClient()?.eventStream(channelId) ?: emptyFlow()
} }
} }
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.model.QrResponse import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.api.SettingsClient import dev.zxq5.chatapp.android.api.SettingsClient
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.TotpStatus import dev.zxq5.chatapp.android.api.model.TotpStatus
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.core.error.ApiResult import dev.zxq5.chatapp.android.core.error.ApiResult
@@ -25,6 +26,10 @@ class SettingsRepository(private val tokenStore: TokenStore) {
_lastToken = null _lastToken = null
} }
suspend fun createInvite(request: InviteRequest): ApiResult<String> {
return getSettingsClient()?.createInvite(request) ?: ApiResult.NetworkError("Not authenticated")
}
suspend fun getTotpQr(password: String): ApiResult<QrResponse?> { suspend fun getTotpQr(password: String): ApiResult<QrResponse?> {
val settingsClient = getSettingsClient() ?: return ApiResult.NetworkError("Not authenticated") val settingsClient = getSettingsClient() ?: return ApiResult.NetworkError("Not authenticated")
return settingsClient.getTotpQr(password) return settingsClient.getTotpQr(password)
@@ -3,7 +3,7 @@ package dev.zxq5.chatapp.android.feature.auth
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import dev.zxq5.chatapp.android.model.LoginState import dev.zxq5.chatapp.android.api.model.LoginState
@Composable @Composable
fun AuthScreen(viewModel: AuthViewModel) { fun AuthScreen(viewModel: AuthViewModel) {
@@ -7,7 +7,7 @@ import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.data.repository.LoginResult import dev.zxq5.chatapp.android.data.repository.LoginResult
import dev.zxq5.chatapp.android.data.repository.SignupResult import dev.zxq5.chatapp.android.data.repository.SignupResult
import dev.zxq5.chatapp.android.data.repository.AuthState import dev.zxq5.chatapp.android.data.repository.AuthState
import dev.zxq5.chatapp.android.model.LoginState import dev.zxq5.chatapp.android.api.model.LoginState
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@@ -28,7 +28,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.model.LoginState import dev.zxq5.chatapp.android.api.model.LoginState
import dev.zxq5.chatapp.android.ui.components.TextField import dev.zxq5.chatapp.android.ui.components.TextField
@Composable @Composable
@@ -28,7 +28,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.model.LoginState import dev.zxq5.chatapp.android.api.model.LoginState
import dev.zxq5.chatapp.android.ui.components.TextField import dev.zxq5.chatapp.android.ui.components.TextField
@Composable @Composable
@@ -1,20 +1,25 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.feature.chat package dev.zxq5.chatapp.android.feature.chat
import android.util.Log import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.Channel import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.data.repository.ChatRepository
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.Space
import dev.zxq5.chatapp.android.api.model.SpaceDto import dev.zxq5.chatapp.android.api.model.SpaceDto
import dev.zxq5.chatapp.android.core.service.MessageStreamService import dev.zxq5.chatapp.android.core.service.MessageStreamService
import io.ktor.client.plugins.ResponseException
import io.ktor.http.HttpStatusCode
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() { class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
@@ -41,6 +46,8 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
private var streamJob: Job? = null private var streamJob: Job? = null
var onUnauthorized: (() -> Unit)? = null
init { init {
_currentUserId.value = chatRepository.getUserId() _currentUserId.value = chatRepository.getUserId()
observeChannel() observeChannel()
@@ -49,6 +56,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
fun loadAccessibleChannels() { fun loadAccessibleChannels() {
_error.value = null _error.value = null
_currentUserId.value = chatRepository.getUserId()
viewModelScope.launch { viewModelScope.launch {
runCatching { runCatching {
chatRepository.getAccessibleChannels() chatRepository.getAccessibleChannels()
@@ -56,11 +64,16 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_spaces.value = data _spaces.value = data
}.onFailure { e -> }.onFailure { e ->
Log.e("Chat", "Failed to load spaces", e) Log.e("Chat", "Failed to load spaces", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_error.value = "Failed to load channels: ${e.message}" _error.value = "Failed to load channels: ${e.message}"
} }
} }
} }
}
@OptIn(ExperimentalTime::class)
private fun observeChannel() { private fun observeChannel() {
viewModelScope.launch { viewModelScope.launch {
_channelId.collect { id -> _channelId.collect { id ->
@@ -69,13 +82,40 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_channelError.value = null _channelError.value = null
if (id != null) { if (id != null) {
streamJob = launch { streamJob = launch {
chatRepository.messageStream(id) chatRepository.eventStream(id)
.catch { e -> .catch { e ->
Log.e("Chat", "Stream error", e) Log.e("Chat", "Stream error", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_channelError.value = "Connection lost: ${e.message}" _channelError.value = "Connection lost: ${e.message}"
} }
.collect { message -> }
_messages.update { it + message } .collect { event ->
when (event) {
is ChatEvent.SendMessage -> {
_messages.update { it + event.data }
}
is ChatEvent.EditMessage -> {
_messages.update { messages ->
messages.map {
if (it.id == event.data.id) event.data.message
else it
}
}
}
is ChatEvent.MessageAppendContent -> {
_messages.update { messages ->
messages.map { msg ->
if (msg.id == event.data.id) {
msg.copy(text = msg.text + event.data.content)
} else {
msg
}
}
}
}
}
} }
} }
} }
@@ -108,15 +148,20 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
) )
}.onFailure { e -> }.onFailure { e ->
Log.e("Chat", "Send message error", e) Log.e("Chat", "Send message error", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_channelError.value = "Failed to send message" _channelError.value = "Failed to send message"
} }
} }
} }
}
fun clearChat() { fun clearChat() {
_messages.value = emptyList() _messages.value = emptyList()
_channelId.value = null _channelId.value = null
_currentUserId.value = null _currentUserId.value = null
_spaces.value = emptyList()
_error.value = null _error.value = null
_channelError.value = null _channelError.value = null
streamJob?.cancel() streamJob?.cancel()
@@ -1,3 +1,5 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.feature.chat package dev.zxq5.chatapp.android.feature.chat
import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.BorderStroke
@@ -65,6 +67,7 @@ import dev.zxq5.chatapp.android.api.model.Message
import java.text.DateFormat import java.text.DateFormat
import java.util.Date import java.util.Date
import kotlin.time.ExperimentalTime import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
@Composable @Composable
fun ChatScreen( fun ChatScreen(
@@ -277,7 +280,7 @@ fun MessageScreen(channelId: Long, viewModel: ChatViewModel, onBack: () -> Unit)
modifier = Modifier.weight(1f).padding(horizontal = 16.dp), modifier = Modifier.weight(1f).padding(horizontal = 16.dp),
verticalArrangement = Arrangement.spacedBy(10.dp) verticalArrangement = Arrangement.spacedBy(10.dp)
) { ) {
items(messages) { message -> items(messages, key = { it.id }) { message ->
MessageBubble(message, currentUserId) MessageBubble(message, currentUserId)
} }
item { Spacer(Modifier.height(10.dp)) } item { Spacer(Modifier.height(10.dp)) }
@@ -378,7 +381,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
horizontalAlignment = if (isMe) Alignment.End else Alignment.Start horizontalAlignment = if (isMe) Alignment.End else Alignment.Start
) { ) {
Surface( Surface(
color = if (isMe) MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f), color = if (isMe) MaterialTheme.colorScheme.surfaceVariant else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f),
shape = RoundedCornerShape( shape = RoundedCornerShape(
topStart = 14.dp, topStart = 14.dp,
topEnd = 14.dp, topEnd = 14.dp,
@@ -388,14 +391,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f)) border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f))
) { ) {
Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) { Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) {
if (!isMe) {
Text(
message.display_name?.lowercase() ?: "unknown",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.primary.copy(alpha = 0.7f),
modifier = Modifier.padding(bottom = 2.dp)
)
}
Text( Text(
text = message.text, text = message.text,
style = MaterialTheme.typography.bodyLarge, style = MaterialTheme.typography.bodyLarge,
@@ -403,10 +399,11 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
) )
} }
} }
Text( Text(
text = time, text = if (!isMe) message.display_name.lowercase() + " . " + time else time,
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f), color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp) modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp)
) )
} }
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.feature.settings
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.QrResponse import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.core.error.ApiResult import dev.zxq5.chatapp.android.core.error.ApiResult
import dev.zxq5.chatapp.android.data.repository.SettingsRepository import dev.zxq5.chatapp.android.data.repository.SettingsRepository
@@ -27,6 +28,9 @@ class SettingsViewModel(private val settingsRepository: SettingsRepository) : Vi
private val _isSuccessState = MutableStateFlow<Map<String, Boolean>>(emptyMap()) private val _isSuccessState = MutableStateFlow<Map<String, Boolean>>(emptyMap())
val isSuccessState: StateFlow<Map<String, Boolean>> = _isSuccessState val isSuccessState: StateFlow<Map<String, Boolean>> = _isSuccessState
private val _lastInviteCode = MutableStateFlow<String?>(null)
val lastInviteCode: StateFlow<String?> = _lastInviteCode
fun clearMessages() { fun clearMessages() {
_settingsError.value = null _settingsError.value = null
_totpError.value = null _totpError.value = null
@@ -40,6 +44,20 @@ class SettingsViewModel(private val settingsRepository: SettingsRepository) : Vi
} }
} }
fun createInvite(request: InviteRequest) {
viewModelScope.launch {
_settingsError.value = null
when (val result = settingsRepository.createInvite(request)) {
is ApiResult.Success -> {
_lastInviteCode.value = result.data
triggerSuccess("invite")
}
is ApiResult.HttpError -> _settingsError.value = result.message
is ApiResult.NetworkError -> _settingsError.value = result.message
}
}
}
fun fetchTotpStatus() { fun fetchTotpStatus() {
viewModelScope.launch { viewModelScope.launch {
when (val result = settingsRepository.getTotpStatus()) { when (val result = settingsRepository.getTotpStatus()) {
@@ -23,11 +23,14 @@ import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.filled.ArrowBack import androidx.compose.material.icons.automirrored.filled.ArrowBack
import androidx.compose.material.icons.filled.ContentCopy
import androidx.compose.material.icons.filled.KeyboardArrowDown import androidx.compose.material.icons.filled.KeyboardArrowDown
import androidx.compose.material.icons.filled.KeyboardArrowUp import androidx.compose.material.icons.filled.KeyboardArrowUp
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.ButtonDefaults import androidx.compose.material3.ButtonDefaults
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.DatePicker
import androidx.compose.material3.DatePickerDialog
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
@@ -37,8 +40,10 @@ import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.OutlinedTextFieldDefaults import androidx.compose.material3.OutlinedTextFieldDefaults
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBar import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.rememberDatePickerState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
@@ -57,9 +62,15 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import android.util.Base64 import android.util.Base64
import android.graphics.BitmapFactory import android.graphics.BitmapFactory
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import dev.zxq5.chatapp.android.api.model.InviteRequest
import kotlin.time.Duration.Companion.days
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class, ExperimentalTime::class)
@Composable @Composable
fun SettingsScreen( fun SettingsScreen(
viewModel: SettingsViewModel, viewModel: SettingsViewModel,
@@ -70,6 +81,7 @@ fun SettingsScreen(
val settingsError by viewModel.settingsError.collectAsState() val settingsError by viewModel.settingsError.collectAsState()
val isSuccessState by viewModel.isSuccessState.collectAsState() val isSuccessState by viewModel.isSuccessState.collectAsState()
val totpError by viewModel.totpError.collectAsState() val totpError by viewModel.totpError.collectAsState()
val lastInviteCode by viewModel.lastInviteCode.collectAsState()
LaunchedEffect(Unit) { LaunchedEffect(Unit) {
viewModel.clearMessages() viewModel.clearMessages()
@@ -274,6 +286,120 @@ fun SettingsScreen(
} }
} }
SettingsSection(title = "invite") {
var inviteName by remember { mutableStateOf("") }
var maxUses by remember { mutableStateOf("1") }
val clipboardManager = LocalClipboardManager.current
var showDatePicker by remember { mutableStateOf(false) }
val datePickerState = rememberDatePickerState(
initialSelectedDateMillis = System.currentTimeMillis() + 7.days.inWholeMilliseconds
)
Text("create invite token", style = MaterialTheme.typography.bodyMedium, modifier = Modifier.padding(bottom = 8.dp))
OutlinedTextField(
value = inviteName,
onValueChange = { inviteName = it },
label = { Text("name") },
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp)
)
Spacer(Modifier.height(8.dp))
OutlinedTextField(
value = maxUses,
onValueChange = { if (it.all { c -> c.isDigit() }) maxUses = it },
label = { Text("max uses") },
keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp)
)
Spacer(Modifier.height(8.dp))
OutlinedTextField(
value = datePickerState.selectedDateMillis?.let { Instant.fromEpochMilliseconds(it).toString().substringBefore("T") } ?: "",
onValueChange = {},
label = { Text("expiry date") },
readOnly = true,
trailingIcon = {
IconButton(onClick = { showDatePicker = true }) {
Icon(Icons.Default.KeyboardArrowDown, contentDescription = "Select Date")
}
},
modifier = Modifier.fillMaxWidth().clickable { showDatePicker = true },
shape = RoundedCornerShape(8.dp)
)
if (showDatePicker) {
DatePickerDialog(
onDismissRequest = { showDatePicker = false },
confirmButton = {
TextButton(onClick = { showDatePicker = false }) {
Text("ok")
}
}
) {
DatePicker(state = datePickerState)
}
}
Spacer(Modifier.height(12.dp))
SuccessButton(
onClick = {
val nowMs = System.currentTimeMillis()
val expiryMs = datePickerState.selectedDateMillis ?: (nowMs + 7.days.inWholeMilliseconds)
viewModel.createInvite(
InviteRequest(
name = inviteName,
max_uses = maxUses.toIntOrNull() ?: 1,
start_date = Instant.fromEpochMilliseconds(nowMs),
expiry_date = Instant.fromEpochMilliseconds(expiryMs)
)
)
},
label = "generate invite",
isSuccess = isSuccessState["invite"] == true,
enabled = inviteName.isNotBlank(),
modifier = Modifier.fillMaxWidth()
)
if (lastInviteCode != null) {
Spacer(Modifier.height(16.dp))
Row(
modifier = Modifier
.fillMaxWidth()
.background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f), RoundedCornerShape(8.dp))
.padding(12.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween
) {
Text(
text = lastInviteCode!!,
style = MaterialTheme.typography.bodyLarge,
modifier = Modifier.weight(1f)
)
IconButton(onClick = {
clipboardManager.setText(AnnotatedString(lastInviteCode!!))
}) {
Icon(Icons.Default.ContentCopy, contentDescription = "Copy", modifier = Modifier.size(20.dp))
}
}
}
}
SettingsSection(title = "session") {
Button(
onClick = onLogout,
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp),
colors = ButtonDefaults.buttonColors(containerColor = Color.White, contentColor = Color.Black)
) {
Text("logout")
}
}
SettingsSection(title = "danger zone", color = Color.Red.copy(alpha = 0.7f)) { SettingsSection(title = "danger zone", color = Color.Red.copy(alpha = 0.7f)) {
var deletePassword by remember { mutableStateOf("") } var deletePassword by remember { mutableStateOf("") }
var deleteTotp by remember { mutableStateOf("") } var deleteTotp by remember { mutableStateOf("") }
@@ -337,18 +463,6 @@ fun SettingsScreen(
} }
} }
SettingsSection(title = "session") {
Spacer(Modifier.height(16.dp))
Button(
onClick = onLogout,
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp),
colors = ButtonDefaults.buttonColors(containerColor = Color.White, contentColor = Color.Black)
) {
Text("logout")
}
}
if (settingsError != null) { if (settingsError != null) {
Text(settingsError!!, color = Color.Red, style = MaterialTheme.typography.bodySmall, modifier = Modifier.padding(top = 8.dp)) Text(settingsError!!, color = Color.Red, style = MaterialTheme.typography.bodySmall, modifier = Modifier.padding(top = 8.dp))
} }
@@ -457,6 +571,7 @@ fun SuccessButton(
} }
} }
@OptIn(ExperimentalTime::class)
@Composable @Composable
fun TwoFactorSetup( fun TwoFactorSetup(
qrCodeBase64: String?, qrCodeBase64: String?,
@@ -511,15 +626,13 @@ fun TwoFactorSetup(
Text(error.lowercase(), color = Color.Red, style = MaterialTheme.typography.labelSmall, modifier = Modifier.padding(top = 8.dp)) Text(error.lowercase(), color = Color.Red, style = MaterialTheme.typography.labelSmall, modifier = Modifier.padding(top = 8.dp))
} }
Spacer(Modifier.height(24.dp)) Spacer(Modifier.height(16.dp))
SuccessButton(
Button( onClick = { onConfirm(code) },
onClick = { if (code.length == 6) onConfirm(code) }, label = "verify and enable",
isSuccess = false, // Managed by parent
enabled = code.length == 6, enabled = code.length == 6,
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth()
shape = RoundedCornerShape(8.dp) )
) {
Text("confirm code")
}
} }
} }
@@ -2,12 +2,14 @@ package dev.zxq5.chatapp.android.ui.components
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedTextField import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.OutlinedTextFieldDefaults import androidx.compose.material3.OutlinedTextFieldDefaults
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.text.input.PasswordVisualTransformation import androidx.compose.ui.text.input.PasswordVisualTransformation
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
@@ -31,6 +33,11 @@ fun TextField(
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
singleLine = true, singleLine = true,
textStyle = MaterialTheme.typography.bodyLarge, textStyle = MaterialTheme.typography.bodyLarge,
keyboardOptions = if (isPassword) {
KeyboardOptions(keyboardType = KeyboardType.Password)
} else {
KeyboardOptions.Default
},
visualTransformation = if (isPassword) PasswordVisualTransformation() else androidx.compose.ui.text.input.VisualTransformation.None, visualTransformation = if (isPassword) PasswordVisualTransformation() else androidx.compose.ui.text.input.VisualTransformation.None,
shape = RoundedCornerShape(8.dp), shape = RoundedCornerShape(8.dp),
colors = OutlinedTextFieldDefaults.colors( colors = OutlinedTextFieldDefaults.colors(
@@ -40,6 +47,6 @@ fun TextField(
unfocusedBorderColor = MaterialTheme.colorScheme.outline, unfocusedBorderColor = MaterialTheme.colorScheme.outline,
focusedTextColor = MaterialTheme.colorScheme.onSurface, focusedTextColor = MaterialTheme.colorScheme.onSurface,
unfocusedTextColor = MaterialTheme.colorScheme.onSurface unfocusedTextColor = MaterialTheme.colorScheme.onSurface
) ),
) )
} }
@@ -2,7 +2,7 @@ package dev.zxq5.chatapp.android.ui.theme
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
val Black = Color(0xFF0A0A0A) val Black = Color(0xFF000000)
val DarkGrey = Color(0xFF0D0D0D) val DarkGrey = Color(0xFF0D0D0D)
val Grey = Color(0xFF141414) val Grey = Color(0xFF141414)
val LightGrey = Color(0xFF1E1E1E) val LightGrey = Color(0xFF1E1E1E)
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="uk.co.ben_gibson.git.link.SettingsState">
<option name="host" value="e0f86390-1091-4871-8aeb-f534fbc99cf0" />
</component>
</project>
+1 -1
View File
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="DataSourcePerFileMappings"> <component name="DataSourcePerFileMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" /> <file url="file://$PROJECT_DIR$/sql/test.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
<file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" /> <file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
</component> </component>
</project> </project>
+2 -1
View File
@@ -1,7 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="SqlDialectMappings"> <component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" dialect="PostgreSQL" /> <file url="file://$PROJECT_DIR$/migrations/20260412200102_message_id_to_uuid.sql" dialect="PostgreSQL" />
<file url="file://$PROJECT_DIR$/sql/test.sql" dialect="PostgreSQL" />
<file url="PROJECT" dialect="PostgreSQL" /> <file url="PROJECT" dialect="PostgreSQL" />
</component> </component>
</project> </project>
+3 -3
View File
@@ -12,7 +12,7 @@ image = "0.25.8"
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] } jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
rand = "0.8" rand = "0.8"
redis = { version = "0.25.4", features = ["tokio-comp"] } redis = { version = "0.25.4", features = ["tokio-comp"] }
reqwest = { version = "0.12.23", features = ["json"] } reqwest = { version = "0.12.23", features = ["json", "stream"] }
rocket = { version = "0.5.1", features = ["json", "secrets"] } rocket = { version = "0.5.1", features = ["json", "secrets"] }
rocket_cors = "0.6.0" rocket_cors = "0.6.0"
rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] } rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] }
@@ -20,11 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] }
serde = { version = "1.0.228", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145" serde_json = "1.0.145"
sha2 = "0.10.9" sha2 = "0.10.9"
sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] } sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time", "uuid"] }
tokio = { version = "1.47.1", features = ["full"] } tokio = { version = "1.47.1", features = ["full"] }
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] } totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
tracing = "0.1.44" tracing = "0.1.44"
uuid = { version = "1.18.1", features = ["v4"] } uuid = { version = "1.18.1", features = ["serde", "v4"] }
thiserror = "1.0.69" thiserror = "1.0.69"
utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] } utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] }
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
+1 -1
View File
@@ -1,7 +1,7 @@
[debug] [debug]
secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU=" secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU="
address = "0.0.0.0" address = "0.0.0.0"
port = 8000 port = 8080
[debug.databases.postgres_db] [debug.databases.postgres_db]
url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev" url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev"
+2
View File
@@ -8,6 +8,8 @@ services:
- redis - redis
env_file: env_file:
- .env - .env
environment:
- RELEASE_MODE=1
redis: redis:
container_name: chatapp_redis container_name: chatapp_redis
@@ -0,0 +1,9 @@
ALTER TABLE attachments DROP CONSTRAINT attachments_message_id_fkey;
ALTER TABLE messages ALTER COLUMN id DROP DEFAULT;
ALTER TABLE messages ALTER COLUMN id TYPE uuid USING gen_random_uuid();
ALTER TABLE messages ALTER COLUMN id SET DEFAULT gen_random_uuid();
ALTER TABLE attachments ALTER COLUMN message_id TYPE uuid USING gen_random_uuid();
ALTER TABLE attachments ADD CONSTRAINT attachments_message_id_fkey
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE;
+17
View File
@@ -0,0 +1,17 @@
WITH space1 AS (
INSERT INTO spaces (name, description, owner_id)
VALUES ('general', 'Boring chat idk', 1)
RETURNING id
),
space2 AS (
INSERT INTO spaces (name, description, owner_id)
VALUES ('Gaming', 'we lose games', 1)
RETURNING id
)
INSERT INTO channels (name, description, space_id)
SELECT 'General', 'General chat', id FROM space1 UNION ALL
SELECT 'Coding', 'Coding stuff', id FROM space1 UNION ALL
SELECT 'AI', '"/ask" here pls :)', id FROM space1 UNION ALL
SELECT 'The Game', '(You lost)', id FROM space2 UNION ALL
SELECT 'Backrooms', 'Beware of Smilers', id FROM space2 UNION ALL
SELECT 'SE', 'Space/Software engineering.', id FROM space2;
+9 -16
View File
@@ -9,15 +9,7 @@ use rocket::{Shutdown, State, ___internal_EventStream as EventStream};
use sqlx::FromRow; use sqlx::FromRow;
use tokio::select; use tokio::select;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use crate::model::event::{ChatEvent, ChatMsg};
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
#[post("/chat/<channel_id>", format = "json", data = "<msg>")] #[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message( pub async fn post_message(
@@ -36,24 +28,25 @@ pub async fn event_stream(
mut shutdown: Shutdown, mut shutdown: Shutdown,
channel_id: i64, channel_id: i64,
) -> ApiResult<EventStream![]> { ) -> ApiResult<EventStream![]> {
let messages = chat.get_messages(channel_id, 100) let messages = chat.fetch_latest_messages_desc(channel_id, 100)
.await?; // if get message returned err, inform user. .await?; // if get message returned err, inform user.
let mut rx = chat.subscribe(channel_id).await; let mut rx = chat.subscribe(channel_id).await;
let id = s.uid; let id = s.uid;
Ok(EventStream! { Ok(EventStream! {
for msg in messages { for msg in messages.into_iter().rev() {
yield Event::json(&msg); // tracing::info!("sending: {:?}", serde_json::to_string(&ChatEvent::SendMessage(msg.clone())).unwrap());
yield Event::json(&ChatEvent::SendMessage(msg));
} }
loop { loop {
select!{ select!{
_ = &mut shutdown => break, // exit early on shutdown _ = &mut shutdown => break, // exit early on shutdown
msg = rx.recv() => match msg { event = rx.recv() => match event {
Ok(msg) => { Ok(event) => {
tracing::info!("yielding message!"); // tracing::info!("yielding event: {event:?}");
yield Event::json(&msg) yield Event::json(&event)
}, },
Err(broadcast::error::RecvError::Lagged(n)) => { Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",); tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",);
+27
View File
@@ -0,0 +1,27 @@
use rocket::serde::{Deserialize, Serialize};
use sqlx::FromRow;
use chrono::{DateTime, Utc};
use uuid::Uuid;
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub id: Uuid,
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ChatEvent {
SendMessage(ChatMsg),
/// for when a user explicitly edits a message
EditMessage { id: Uuid, msg: ChatMsg },
/// used for streaming content to a message
/// will not show up as edited
MessageAppendContent{ id: Uuid, content: String }
}
+1
View File
@@ -1,3 +1,4 @@
pub mod auth; pub mod auth;
pub mod user; pub mod user;
pub mod space; pub mod space;
pub mod event;
+35 -23
View File
@@ -1,68 +1,80 @@
use crate::api::chat::ChatMsg; use crate::model::event::ChatMsg;
use crate::repo::Repo; use crate::repo::Repo;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use sqlx::PgPool; use sqlx::PgPool;
use uuid::Uuid;
#[derive(Clone)] #[derive(Clone)]
pub struct MessageRepository { pub struct MessageRepository {
pool: PgPool pool: PgPool
} }
impl Repo for MessageRepository { impl MessageRepository {
type Target = ChatMsg; pub(crate) fn new(pool: PgPool) -> Self {
fn new(pool: PgPool) -> Self {
Self { pool } Self { pool }
} }
// TODO: caching with redis // TODO: caching with redis
async fn get_by_id(&self, id: i64) -> Option<Self::Target> { async fn get_by_id(&self, id: Uuid) -> Option<ChatMsg> {
sqlx::query!( sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at "SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m FROM messages m
JOIN users u ON m.user_id = u.id JOIN users u ON m.user_id = u.id
WHERE m.id = $1", WHERE m.id = $1",
id id
).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg { ).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg {
id: row.id,
display_name: Some(row.nickname.unwrap_or(row.username)), display_name: Some(row.nickname.unwrap_or(row.username)),
user_id: row.user_id, user_id: row.user_id,
text: row.content, text: row.content,
timestamp: row.created_at, timestamp: row.created_at,
}) })
} }
}
impl MessageRepository {
// TODO! caching with redis // TODO! caching with redis
pub async fn create_new( pub async fn create_new(
&self, uid: i64, channel_id: i64, &self, msg: ChatMsg, channel_id: i64
text: &str, created_at: DateTime<Utc> ) -> Result<(), sqlx::Error> {
) -> Result<i64, sqlx::Error> {
sqlx::query!( sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) "INSERT INTO messages (id, channel_id, user_id, content, created_at)
VALUES ($1, $2, $3, $4) RETURNING id", VALUES ($1, $2, $3, $4, $5)",
msg.id,
channel_id, channel_id,
uid, msg.user_id,
msg.text,
msg.timestamp
).execute(&self.pool).await.map_err(|_| sqlx::Error::RowNotFound)?;
Ok(())
}
pub async fn update_text(&self, id: Uuid, text: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE messages SET content = $1 WHERE id = $2",
text, text,
created_at id
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound)) ).execute(&self.pool).await?;
Ok(())
} }
/// TODO: caching with redis /// TODO: caching with redis
pub async fn get_by_channel(&self, channel_id: i64, limit: usize) pub async fn get_latest_by_channel_desc(
-> Result<Vec<ChatMsg>, sqlx::Error> { &self, channel_id: i64, limit: usize, page: usize
) -> Result<Vec<ChatMsg>, sqlx::Error> {
sqlx::query!( sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at "SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m FROM messages m
JOIN users u ON m.user_id = u.id JOIN users u ON m.user_id = u.id
WHERE m.channel_id = $1 WHERE m.channel_id = $1
ORDER BY m.created_at DESC LIMIT $2", ORDER BY m.created_at DESC
LIMIT $2 OFFSET $3",
channel_id, channel_id,
limit as i64 limit as i64,
page as i64
).fetch_all(&self.pool).await.map(|messages| { ).fetch_all(&self.pool).await.map(|messages| {
messages.into_iter().rev().map(|msg| { messages.into_iter().map(|msg| {
ChatMsg { ChatMsg {
id: msg.id,
display_name: Some(msg.nickname.unwrap_or(msg.username)), display_name: Some(msg.nickname.unwrap_or(msg.username)),
user_id: msg.user_id, user_id: msg.user_id,
text: msg.content, text: msg.content,
+104 -42
View File
@@ -1,12 +1,15 @@
use crate::api::chat::ChatMsg; use crate::model::event::{ChatEvent, ChatMsg};
use crate::error::{ApiResult, AppError}; use crate::error::{ApiResult, AppError};
use crate::repo::message_repo::MessageRepository; use crate::repo::message_repo::MessageRepository;
use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo}; use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use std::collections::HashMap; use std::collections::{HashMap, VecDeque};
use std::sync::Arc; use std::sync::Arc;
use std::sync::mpsc::channel;
use std::time::Instant;
use tokio::sync::broadcast::Sender; use tokio::sync::broadcast::Sender;
use tokio::sync::{broadcast, Mutex}; use tokio::sync::{broadcast, Mutex};
use uuid::Uuid;
use crate::model::space::SpaceDto; use crate::model::space::SpaceDto;
use crate::svc::llm_service::LlmService; use crate::svc::llm_service::LlmService;
@@ -15,15 +18,13 @@ use crate::svc::llm_service::LlmService;
#[derive(Clone)] #[derive(Clone)]
pub struct ChatService { pub struct ChatService {
users: Arc<dyn UserRepo>, users: Arc<dyn UserRepo>,
channels: Arc<dyn ChannelRepo>, channel_repo: Arc<dyn ChannelRepo>,
spaces: Arc<dyn SpaceRepo>, spaces: Arc<dyn SpaceRepo>,
messages: MessageRepository, messages: MessageRepository,
llm: LlmService, llm: LlmService,
buffer_size: usize, buffer_size: usize,
senders: Arc<Mutex<HashMap<i64, Sender<ChatMsg>>>>, channels: Arc<Mutex<HashMap<i64, ChannelState>>>,
} }
impl ChatService { impl ChatService {
@@ -33,13 +34,13 @@ impl ChatService {
channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>, channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>,
) -> Self { ) -> Self {
Self { Self {
channels, channel_repo: channels,
spaces, spaces,
llm, llm,
users, users,
messages, messages,
buffer_size, buffer_size,
senders: Arc::new(Mutex::new(std::collections::HashMap::new())), channels: Arc::new(Mutex::new(std::collections::HashMap::new())),
} }
} }
@@ -50,7 +51,7 @@ impl ChatService {
let mut result = Vec::new(); let mut result = Vec::new();
for space in spaces { for space in spaces {
let channels = self.channels.get_by_space_id(space.id).await?; let channels = self.channel_repo.get_by_space_id(space.id).await?;
result.push(SpaceDto { result.push(SpaceDto {
channels, channels,
id: space.id, id: space.id,
@@ -65,8 +66,10 @@ impl ChatService {
Ok(result) Ok(result)
} }
pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> { pub async fn fetch_latest_messages_desc(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
let messages = self.messages.get_by_channel(channel_id, limit).await?; const PAGE: usize = 0;
let messages = self.messages.get_latest_by_channel_desc(channel_id, limit, PAGE).await?;
Ok(messages) Ok(messages)
} }
@@ -113,6 +116,7 @@ impl ChatService {
.ok_or(AppError::NotFound)?; .ok_or(AppError::NotFound)?;
let message = ChatMsg { let message = ChatMsg {
id: Uuid::new_v4(),
display_name: Some(user display_name: Some(user
.nickname.clone() .nickname.clone()
.unwrap_or_else(|| user.username.clone())), .unwrap_or_else(|| user.username.clone())),
@@ -122,66 +126,124 @@ impl ChatService {
}; };
self.publish(channel_id, message.clone()).await; self.publish(channel_id, message.clone()).await;
self.messages.create_new(message.clone(), channel_id).await?;
let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?;
// TODO: caching w redis at repository layer // TODO: caching w redis at repository layer
let svc_instance = self.clone(); if !message.text.starts_with("/ask ") {
return Ok(());
let Some(text) = text.strip_prefix("/ask ") else {
return Ok(())
};
if !svc_instance.llm.enabled() {
return Ok(())
} }
let svc_instance = self.clone();
if !svc_instance.llm.enabled() {
return Ok(());
}
let context = self.get_history(channel_id, 25).await?;
tokio::spawn(async move { tokio::spawn(async move {
let sender = match svc_instance.channels.lock().await.get(&channel_id) {
Some(s) => s.get_sender(),
None => return,
};
let response = svc_instance.llm let response = svc_instance.llm
.query(&message) .query(&message, &context, sender)
.await; .await;
if let Ok(reply) = response { let Ok(reply) = response else {
tracing::warn!("Error contacting LLM: {:?}", response);
return;
};
tracing::info!("LLM reply: {}", reply.text);
svc_instance.publish(channel_id, reply.clone()).await;
// TODO: cache response (or do with redis!) // TODO: cache response (or do with redis!)
if let Err(e) = svc_instance.messages if let Err(e) = svc_instance.messages.create_new(reply, channel_id).await {
.create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await {
tracing::error!("Failed to persist LLM reply: {}", e); tracing::error!("Failed to persist LLM reply: {}", e);
} }
tracing::info!("LLM reply persisted"); tracing::debug!("Full LLM reply persisted");
} else {
tracing::warn!("Error contacting LLM: {:?}", response);
}
}); });
Ok(()) Ok(())
} }
async fn get_history(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
let mut map = self.channels.lock().await;
if let Some(channel) = map.get(&channel_id) && !channel.cache.is_empty() {
Ok(channel.history().clone().into_iter().take(limit).collect())
} else {
let messages: Vec<_> = self.messages.get_latest_by_channel_desc(channel_id, limit, 0).await?.into_iter().rev().collect();
map.insert(channel_id,
ChannelState::new(self.buffer_size, Some(messages.clone().into()))
);
Ok(messages)
}
}
/// Subscribe to the specified channel. /// Subscribe to the specified channel.
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatMsg> { pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatEvent> {
let mut map = self.senders.lock().await; let mut map = self.channels.lock().await;
let sender = map let channel = map
.entry(channel_id) .entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0); .or_insert_with(|| ChannelState::new(self.buffer_size, None));
sender.subscribe() channel.subscribe()
} }
// Private helper methods // Private helper methods
/// Publish a message to the specified channel. /// Publish a message to the specified channel.
async fn publish(&self, channel_id: i64, msg: ChatMsg) { async fn publish(&self, channel_id: i64, msg: ChatMsg) {
let mut map = self.senders.lock().await; let mut map = self.channels.lock().await;
let sender = map let channel = map
.entry(channel_id) .entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0); .or_insert_with(|| ChannelState::new(self.buffer_size, None));
let _ = sender.send(msg); channel.send(msg);
}
}
#[derive(Clone)]
pub struct ChannelState {
sender: Sender<ChatEvent>,
cache: VecDeque<ChatMsg>,
last_updated: Instant,
}
impl ChannelState {
const MAX_HISTORY_SIZE: usize = 100;
#[must_use]
pub fn new(buffer_size: usize, history: Option<VecDeque<ChatMsg>>) -> Self {
Self {
sender: broadcast::channel(buffer_size).0,
cache: history.unwrap_or_default(),
last_updated: Instant::now(),
}
} }
pub fn history(&self) -> &VecDeque<ChatMsg> {
&self.cache
}
pub fn get_sender(&self) -> Sender<ChatEvent> {
self.sender.clone()
}
#[must_use]
pub fn send(&mut self, msg: ChatMsg) {
while self.cache.len() >= Self::MAX_HISTORY_SIZE {
self.cache.pop_front();
}
self.cache.push_back(msg.clone());
if self.sender.send(ChatEvent::SendMessage(msg)).is_err() {
tracing::warn!("Sent message to channel with no subscribers");
}
}
#[must_use]
pub fn subscribe(&self) -> broadcast::Receiver<ChatEvent> {
self.sender.subscribe()
}
} }
+119 -41
View File
@@ -1,3 +1,13 @@
use std::env;
use std::sync::LazyLock;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast::Sender;
use uuid::Uuid;
use crate::model::event::{ChatEvent, ChatMsg};
use crate::error::{ApiResult, AppError};
#[derive(Clone)] #[derive(Clone)]
pub struct LlmService; pub struct LlmService;
@@ -15,71 +25,139 @@ impl LlmService {
LMSTUDIO_URL.is_some() LMSTUDIO_URL.is_some()
} }
pub async fn query(&self, message: &ChatMsg) -> ApiResult<ChatMsg> { pub async fn query(&self, message: &ChatMsg, context: &[ChatMsg], sender: Sender<ChatEvent>) -> ApiResult<ChatMsg> {
let Some(url) = LMSTUDIO_URL.clone() else { let Some(url) = LMSTUDIO_URL.clone() else {
return Err(AppError::internal("AI not enabled!")) return Err(AppError::internal("AI not enabled!"))
}; };
let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into()); let reply_id = Uuid::new_v4();
let timestamp = chrono::Utc::now();
let client = reqwest::Client::new(); let mut reply = ChatMsg {
id: reply_id,
// Build the request body display_name: Some("llm".into()),
let payload = LlmRequest { user_id: 0,
model, // whatever model you run locally text: String::new(),
messages: vec![Message { timestamp,
role: "user".into(),
content: message.text.clone(),
}],
}; };
let _ = sender.send(ChatEvent::SendMessage(reply.clone()));
// POST to lmstudio (default 127.0.0.1:1234) let mut messages: Vec<Message> = Vec::new();
let resp = client let system_prompt = format!(
"You are a helpful assistant in a group chat. \
You are talking to '{}'. \
Keep responses concise and conversational.",
message.display_name.as_deref().unwrap_or("unknown"),
);
messages.push(Message { role: "system".into(), content: system_prompt });
for msg in context {
let role = if msg.user_id == 0 {
"assistant" // your LLM user_id convention
} else {
"user"
};
messages.push(Message {
role: role.into(),
content: format!(
"{}: {}",
msg.display_name.as_deref().unwrap_or("unknown"),
msg.text
),
});
}
messages.push(Message {
role: "user".into(),
content: format!(
"{}: {}",
message.display_name.as_deref().unwrap_or("unknown"),
message.text.trim_start_matches("/ask ") // strip the command prefix
),
});
let Ok(resp) = reqwest::Client::new()
.post(url) .post(url)
.json(&payload) .json(&LlmRequest {
think: false,
model: LMSTUDIO_MODEL
.clone()
.unwrap_or_else(|| "gpt-oss-20b".into()),
messages,
stream: true,
})
.send() .send()
.await .await
.map_err(|_| AppError::internal("Failed to make request to LLM server"))?; else {
tracing::warn!("Failed to reach LLM");
let _ = sender.send(ChatEvent::MessageAppendContent {
id: reply_id,
content: String::from("I'm not available right now. Please try again later.")
});
return Err(AppError::internal("Failed to reach LLM"));
};
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)] let mut full_text = String::new();
struct LlmResponse { let mut buffer = String::new();
choices: Vec<Choice>, let mut stream = resp.bytes_stream();
}
#[derive(Deserialize)] while let Some(chunk) = stream.next().await {
struct Choice { let chunk = chunk.map_err(|_| AppError::internal("Stream error"))?;
message: Message, buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find('\n') {
let line = buffer[..pos].trim().to_string();
buffer = buffer[pos + 1..].to_string();
if line == "data: [DONE]" {
break;
} }
let llm_resp: LlmResponse = resp if let Some(json) = line.strip_prefix("data: ")
.json() && let Ok(parsed) = serde_json::from_str::<StreamingResponse>(json)
.await && let Some(content) = parsed.choices[0].delta.content.as_ref() {
.map_err(|_| AppError::internal("Failed to parse LLM response"))?; full_text.push_str(content);
let _ = sender.send(ChatEvent::MessageAppendContent {
id: reply_id,
content: content.clone(),
});
}
}
}
Ok(ChatMsg { reply.text = full_text;
display_name: Some(String::from("llm")),
user_id: 0, Ok(reply)
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now(),
})
} }
} }
use std::env;
use std::sync::LazyLock;
// src/llm.rs
use serde::{Deserialize, Serialize};
use crate::api::chat::ChatMsg;
use crate::error::{ApiResult, AppError};
use crate::svc::chat_svc::ChatService;
#[derive(Serialize)] #[derive(Serialize)]
struct LlmRequest { struct LlmRequest {
model: String, model: String,
messages: Vec<Message>, messages: Vec<Message>,
stream: bool,
think: bool,
} }
#[derive(Deserialize)]
struct StreamingResponse {
choices: Vec<StreamingChoice>,
}
#[derive(Deserialize)]
struct StreamingChoice {
delta: Delta,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct Delta {
#[serde(default)]
content: Option<String>,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct Message { struct Message {
role: String, // "user" or "assistant" role: String, // "user" or "assistant"
+1
View File
@@ -4,3 +4,4 @@ pub mod settings_svc;
pub mod user_svc; pub mod user_svc;
pub mod access_token_svc; pub mod access_token_svc;
pub mod llm_service; pub mod llm_service;
pub mod relationship_svc;
+6
View File
@@ -0,0 +1,6 @@
pub struct RelationshipService {}
impl RelationshipService {
pub fn new() -> Self { Self {} }
}
+228
View File
@@ -0,0 +1,228 @@
import argparse
import json
import threading
import time
from dataclasses import dataclass
from getpass import getpass
from typing import List
import requests
BASE_URL = "http://localhost:8000"
@dataclass
class AuthResult:
token: str
def signup(session: requests.Session, email: str, username: str, password: str, access_token: str) -> None:
url = f"{BASE_URL}/api/signup"
payload = {
"email": email,
"username": username,
"password": password,
"access_token": access_token,
}
resp = session.post(url, json=payload, timeout=10)
resp.raise_for_status()
def login(session: requests.Session, username: str, password: str) -> AuthResult:
url = f"{BASE_URL}/api/login"
payload = {
"username": username,
"password": password,
}
resp = session.post(url, json=payload, timeout=10)
resp.raise_for_status()
data = resp.json()
token = data.get("token")
if not token:
raise RuntimeError(f"Login response did not contain token: {data}")
return AuthResult(token=token)
def post_message(session: requests.Session, channel_id: int, token: str, text: str, display_name: str, user_id: int) -> None:
url = f"{BASE_URL}/api/chat/{channel_id}"
payload = {
"display_name": display_name,
"user_id": user_id,
"text": text,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
headers = {"Authorization": f"Bearer {token}"}
resp = session.post(url, json=payload, headers=headers, timeout=10)
resp.raise_for_status()
def read_sse_messages(
session: requests.Session,
channel_id: int,
token: str,
expected_count: int,
timeout_s: int,
capture_live_messages: threading.Event,
) -> List[dict]:
url = f"{BASE_URL}/api/events/{channel_id}"
headers = {
"Authorization": f"Bearer {token}",
"Accept": "text/event-stream",
}
received: List[dict] = []
deadline = time.monotonic() + timeout_s
try:
with session.get(url, headers=headers, stream=True, timeout=(5, timeout_s)) as resp:
resp.raise_for_status()
event_data_lines: List[str] = []
for raw_line in resp.iter_lines(decode_unicode=True):
if time.monotonic() > deadline:
break
if raw_line is None:
continue
line = raw_line.strip()
if not line:
if event_data_lines:
joined = "\n".join(event_data_lines)
event_data_lines.clear()
try:
obj = json.loads(joined)
except json.JSONDecodeError:
continue
if capture_live_messages.is_set():
received.append(obj)
if expected_count > 0 and len(received) >= expected_count:
break
else:
print(f"Discarding message: {obj}")
continue
if line.startswith("data:"):
event_data_lines.append(line[len("data:"):].strip())
except requests.exceptions.Timeout:
print("Timeout while reading SSE.")
except requests.exceptions.RequestException as exc:
print(f"Error reading SSE: {exc}")
return received
def prompt_nonempty(label: str, secret: bool = False) -> str:
while True:
value = getpass(label) if secret else input(label)
value = value.strip()
if value:
return value
print("Please enter a value.")
def main() -> int:
parser = argparse.ArgumentParser(description="Chat integration test against localhost:8000")
parser.add_argument("--existing-account", action="store_true",
help="Skip signup and only log in with an existing account")
parser.add_argument("--email", default=None)
parser.add_argument("--username", default=None)
parser.add_argument("--password", default=None)
parser.add_argument("--access-token", default=None,
help="Required only for signup mode")
parser.add_argument("--channel-id", type=int, default=1)
parser.add_argument("--message-count", type=int, default=5)
parser.add_argument("--timeout", type=int, default=15)
args = parser.parse_args()
session = requests.Session()
if args.existing_account:
username = args.username or prompt_nonempty("Username: ")
password = args.password or prompt_nonempty("Password: ", secret=True)
else:
email = args.email or prompt_nonempty("Email: ")
username = args.username or prompt_nonempty("Username: ")
password = args.password or prompt_nonempty("Password: ", secret=True)
access_token = args.access_token or prompt_nonempty("Access token: ")
print("[1/5] Signing up...")
try:
signup(session, email, username, password, access_token)
print(" signup ok")
except requests.HTTPError as e:
print(f" signup returned HTTP error: {e}")
print(" continuing to login...")
print("[2/5] Logging in...")
auth = login(session, username, password)
print(" login ok")
print(f" token: {auth.token[:12]}...")
print("[3/5] Opening event stream...")
received_messages: List[dict] = []
capture_live_messages = threading.Event()
stream_done = threading.Event()
def stream_reader() -> None:
nonlocal received_messages
try:
received_messages = read_sse_messages(
session=session,
channel_id=args.channel_id,
token=auth.token,
expected_count=args.message_count,
timeout_s=args.timeout,
capture_live_messages=capture_live_messages,
)
finally:
stream_done.set()
t = threading.Thread(target=stream_reader, daemon=True)
t.start()
# Give the server time to flush backlog on this same stream connection.
time.sleep(1.0)
print("[4/5] Starting to capture live messages and sending messages...")
capture_live_messages.set()
sent_texts = [f"Message {i}" for i in range(args.message_count)]
for i, text in enumerate(sent_texts):
post_message(
session=session,
channel_id=args.channel_id,
token=auth.token,
text=text,
display_name=username,
user_id=1,
)
print(f" sent {i + 1}/{args.message_count}: {text}")
time.sleep(0.1)
stream_done.wait(timeout=args.timeout)
t.join(timeout=1)
print("\nReceived messages:")
for i, msg in enumerate(received_messages, start=1):
print(f" {i}. {msg}")
received_texts = [m.get("text") for m in received_messages if isinstance(m, dict)]
for text in sent_texts:
if text not in received_texts:
print(f"\nFAIL: missing message: {text}")
return 1
print("\nPASS: login and message delivery test succeeded.")
return 0
if __name__ == "__main__":
raise SystemExit(main())