4 Commits

Author SHA1 Message Date
zxq5 7e001d8769 idk 2026-06-03 19:12:23 +01:00
zxq5 2f34976f3e Merge remote-tracking branch 'origin/dev' into dev 2026-04-11 00:10:09 +01:00
zxq5 d1208f7e39 frontend v0.4.1-2
- added invite section to UI and some general bug fixes
2026-04-11 00:09:47 +01:00
zxq5 d6ba875297 addedd RELEASE_MODE=1 to run var to prevent crash in absence of .env
file
2026-04-08 00:05:54 +01:00
42 changed files with 1026 additions and 280 deletions
+1 -5
View File
@@ -1,12 +1,9 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
xmlns:tools="http://tools.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
<application
android:name=".ChatApplication"
@@ -22,7 +19,6 @@
<service
android:name=".core.service.MessageStreamService"
android:foregroundServiceType="dataSync"
android:exported="false"/>
<activity
@@ -1,9 +1,13 @@
package dev.zxq5.chatapp.android
import android.Manifest
import android.os.Build
import android.os.Bundle
import androidx.activity.ComponentActivity
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
@@ -62,14 +66,37 @@ class MainActivity : ComponentActivity() {
val currentScreen by chatViewModel.currentScreen.collectAsState()
val selectedChannelId by chatViewModel.channelId.collectAsState()
// Permission request launcher
val launcher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.RequestPermission(),
onResult = { isGranted ->
if (isGranted && authState == AuthState.Authenticated) {
MessageStreamService.start(this@MainActivity)
}
}
)
LaunchedEffect(authState) {
when (authState) {
AuthState.Authenticated -> MessageStreamService.start(this@MainActivity)
AuthState.Authenticated -> {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
launcher.launch(Manifest.permission.POST_NOTIFICATIONS)
}
MessageStreamService.start(this@MainActivity)
chatViewModel.loadAccessibleChannels()
}
AuthState.Unauthenticated -> MessageStreamService.stop(this@MainActivity)
AuthState.AwaitingTotp -> {}
}
}
LaunchedEffect(Unit) {
chatViewModel.onUnauthorized = {
authViewModel.logout()
chatViewModel.clearChat()
}
}
LaunchedEffect(Unit) {
intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let {
chatViewModel.switchChannel(it.toLong())
@@ -1,7 +1,9 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.api
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.api.model.SendMessage
import dev.zxq5.chatapp.android.api.model.SpaceDto
import io.ktor.client.HttpClient
@@ -25,6 +27,8 @@ import kotlinx.coroutines.flow.flow
import kotlinx.serialization.json.Json
import kotlin.time.Clock
import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
class ChatClient(private val token: String) {
@@ -45,18 +49,18 @@ class ChatClient(private val token: String) {
suspend fun sendMessage(channelId: Long, userId: Int, text: String) {
http.post("${BASE_URL}/api/chat/$channelId") {
contentType(ContentType.Application.Json)
setBody(SendMessage(user_id = userId, text = text, timestamp = Clock.System.now()))
setBody(SendMessage(id = Uuid.random(), user_id = userId, text = text, timestamp = Clock.System.now()))
}
}
fun messageStream(channelId: Long): Flow<Message> = flow {
fun eventStream(channelId: Long): Flow<ChatEvent> = flow {
http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response ->
val channel = response.bodyAsChannel()
while (!channel.isClosedForRead) {
val line = channel.readLine() ?: break
if (line.startsWith("data:")) {
val json = line.removePrefix("data:").trim()
runCatching { Json.decodeFromString<Message>(json) }
runCatching { Json.decodeFromString<ChatEvent>(json) }
.onSuccess { emit(it) }
}
}
@@ -4,6 +4,7 @@ import android.util.Log
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest
import dev.zxq5.chatapp.android.api.model.DisplayNameRequest
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest
import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode
@@ -45,6 +46,22 @@ class SettingsClient(private val token: String) {
}
}
suspend fun createInvite(request: InviteRequest): ApiResult<String> {
return try {
val response = http.post("${BASE_URL}/api/invite") {
contentType(ContentType.Application.Json)
setBody(request)
}
if (response.status.isSuccess()) {
ApiResult.Success(response.body<String>())
} else {
ApiResult.HttpError(response.status.value, "Failed to create invite")
}
} catch (e: Exception) {
ApiResult.NetworkError(e.localizedMessage ?: "Network error")
}
}
suspend fun getTotpQr(password: String): ApiResult<QrResponse> {
return try {
val response = http.post("${BASE_URL}/api/totp.jpg") {
@@ -0,0 +1,60 @@
@file:OptIn(ExperimentalUuidApi::class, ExperimentalTime::class)
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialName
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
import kotlinx.serialization.Serializable
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
@OptIn(ExperimentalSerializationApi::class)
@Serializable
@JsonClassDiscriminator("type")
sealed class ChatEvent {
@Serializable
@SerialName("SendMessage")
data class SendMessage(
val data: Message
) : ChatEvent()
@Serializable
@SerialName("EditMessage")
data class EditMessage(
val data: EditMessageContent
) : ChatEvent()
@Serializable
@SerialName("MessageAppendContent")
data class MessageAppendContent(
val data: AppendContent
) : ChatEvent()
}
// tuple variants like (i64, ChatMsg) and (i64, String)
// need wrapper classes since kotlinx can't deserialise
// bare JSON arrays into data classes directly
@Serializable
data class EditMessageContent(
val id: Uuid,
val message: Message
)
@Serializable
data class AppendContent (
val id: Uuid,
val content: String
)
@Serializable
data class Message (
val id: Uuid,
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Instant
)
@@ -0,0 +1,13 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class InviteRequest @OptIn(ExperimentalTime::class) constructor(
val name: String,
val max_uses: Int,
val expiry_date: Instant,
val start_date: Instant
)
@@ -1,4 +1,4 @@
package dev.zxq5.chatapp.android.model
package dev.zxq5.chatapp.android.api.model
sealed class LoginState {
object Idle : LoginState()
@@ -1,13 +0,0 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Message @OptIn(ExperimentalTime::class) constructor(
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Instant
)
@@ -3,9 +3,12 @@ package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
@Serializable
data class SendMessage @OptIn(ExperimentalTime::class) constructor(
data class SendMessage @OptIn(ExperimentalTime::class, ExperimentalUuidApi::class) constructor(
val id: Uuid,
val user_id: Int,
val text: String,
val timestamp: Instant
@@ -3,10 +3,12 @@ package dev.zxq5.chatapp.android.core.data
import android.content.Context
import android.content.SharedPreferences
import android.util.Base64
import android.util.Log
import androidx.core.content.edit
import androidx.security.crypto.EncryptedSharedPreferences
import androidx.security.crypto.MasterKey
import org.json.JSONObject
import java.time.Instant
private const val KEY = "auth_token"
private const val TWOFA_KEY = "twofa_enabled"
@@ -27,11 +29,37 @@ class TokenStore(appContext: Context) {
)
}
fun save(token: String) =
prefs().edit { putString(KEY, token) }
fun save(token: String) {
Log.d("TokenStore", "Saving token: $token")
prefs().edit { putString(KEY, token) }
}
fun get(): String? {
val ret = prefs().getString(KEY, null)
Log.d("TokenStore", "Retrieved token: $ret")
return ret
}
fun isExpired(): Boolean {
val token = get() ?: return true
return try {
val payload = token.split(".")[1]
val padded = payload + "==".take((4 - payload.length % 4) % 4)
val jsonString = String(Base64.decode(padded, Base64.URL_SAFE))
val json = JSONObject(jsonString)
if (json.has("exp")) {
val exp = json.getLong("exp")
val now = Instant.now().epochSecond
now >= exp
} else {
false // If no exp claim, assume not expired or handle differently
}
} catch (e: Exception) {
true // If we can't parse it, treat it as expired
}
}
fun get(): String? =
prefs().getString(KEY, null)
fun save2faEnabled( enabled: Boolean) =
prefs().edit { putBoolean(TWOFA_KEY, enabled) }
@@ -3,10 +3,10 @@ package dev.zxq5.chatapp.android.core.service
import android.app.Service
import android.content.Context
import android.content.Intent
import android.os.Build
import android.os.IBinder
import android.util.Log
import dev.zxq5.chatapp.android.ChatApplication
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.data.repository.ChatRepository
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
@@ -15,21 +15,17 @@ import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
// core/service/MessageStreamService.kt
class MessageStreamService : Service() {
private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private lateinit var notificationService: NotificationService
private lateinit var chatRepository: ChatRepository
// which channel the user is currently looking at
// set by the ViewModel when the user opens/closes a channel
var activeChannelId: Long? = null
set(value) {
field = value
Log.d("Service", "activeChannelId set to $value")
if (value != null) {
// restart stream with new channel
currentStreamJob?.cancel()
observeMessages()
}
@@ -42,12 +38,8 @@ class MessageStreamService : Service() {
fun start(context: Context) {
val intent = Intent(context, MessageStreamService::class.java)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
context.startForegroundService(intent)
} else {
context.startService(intent)
}
}
fun stop(context: Context) {
context.stopService(Intent(context, MessageStreamService::class.java))
@@ -62,33 +54,29 @@ class MessageStreamService : Service() {
}
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
startForeground(
NotificationService.FOREGROUND_NOTIFICATION_ID,
notificationService.buildForegroundNotification()
)
observeMessages()
return START_STICKY // restart if killed
return START_STICKY
}
private fun observeMessages() {
val channelId = activeChannelId ?: chatRepository.getLastActiveChannel()
Log.d("Service", "observeMessages called, channelId=$channelId")
if (channelId == null) {
Log.d("Service", "No channel to observe, waiting for switchChannel")
return
}
if (channelId == null) return
Log.d("Service", "Starting stream for channel $channelId")
currentStreamJob = serviceScope.launch {
chatRepository.messageStream(channelId)
chatRepository.eventStream(channelId)
.catch { e -> Log.e("Service", "Stream error", e) }
.collect { message ->
if (!ChatApplication.AppState.isInForeground) { // no channel focused, always notify
notificationService.showMessageNotification(
conversationId = activeChannelId.toString(),
senderName = message.display_name,
messagePreview = message.text.take(80)
.collect { event ->
// Only show notification when an event (new message) is received
// and the app is not in the foreground on this channel.
if (!ChatApplication.AppState.isInForeground || activeChannelId != channelId) {
when (event) {
is ChatEvent.SendMessage -> notificationService.showMessageNotification(
conversationId = channelId.toString(),
senderName = event.data.display_name,
messagePreview = event.data.text
)
else -> {}
}
}
}
}
@@ -15,51 +15,41 @@ class NotificationService(private val context: Context) {
companion object {
const val CHANNEL_ID = "messages"
const val FOREGROUND_NOTIFICATION_ID = 1 // ← this needs to exist
const val SERVICE_CHANNEL_ID = "service"
const val FOREGROUND_NOTIFICATION_ID = 1
}
private val manager = context.getSystemService(NotificationManager::class.java)
fun createChannels() {
// channel for new message notifications
val messageChannel = NotificationChannel(
CHANNEL_ID,
"Messages",
NotificationManager.IMPORTANCE_HIGH
).apply {
enableVibration(true)
fun createForegroundNotification(): Notification {
val intent = Intent(context, MainActivity::class.java).apply {
flags = Intent.FLAG_ACTIVITY_NEW_TASK or Intent.FLAG_ACTIVITY_CLEAR_TASK
}
// channel for the persistent foreground service notification
// low importance so it doesn't make noise
val serviceChannel = NotificationChannel(
"service",
"Background connection",
NotificationManager.IMPORTANCE_LOW
val pendingIntent = PendingIntent.getActivity(
context,
0,
intent,
PendingIntent.FLAG_IMMUTABLE
)
val mgr = context.getSystemService(NotificationManager::class.java)
mgr.createNotificationChannel(messageChannel)
mgr.createNotificationChannel(serviceChannel)
}
fun buildForegroundNotification(): Notification {
return NotificationCompat.Builder(context, "service")
return NotificationCompat.Builder(context, SERVICE_CHANNEL_ID)
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle("chatapp")
.setContentText("Connected")
.setContentTitle("Chat App")
.setContentText("Connecting to message stream...")
.setPriority(NotificationCompat.PRIORITY_LOW)
.setCategory(NotificationCompat.CATEGORY_SERVICE)
.setContentIntent(pendingIntent)
.setOngoing(true)
.setSilent(true)
.build()
}
fun showMessageNotification(
conversationId: String,
senderName: String,
messagePreview: String, // for E2E this would be "New message" — no plaintext
messagePreview: String,
notificationId: Int = conversationId.hashCode()
) {
// intent that opens the app to the right conversation when tapped
val intent = Intent(context, MainActivity::class.java).apply {
flags = Intent.FLAG_ACTIVITY_SINGLE_TOP
putExtra("conversation_id", conversationId)
@@ -72,13 +62,13 @@ class NotificationService(private val context: Context) {
PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE
)
val notification = NotificationCompat.Builder(context, "messages")
val notification = NotificationCompat.Builder(context, CHANNEL_ID)
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle(senderName)
.setContentText(messagePreview)
.setPriority(NotificationCompat.PRIORITY_HIGH)
.setContentIntent(pendingIntent)
.setAutoCancel(true) // dismiss on tap
.setAutoCancel(true)
.build()
manager.notify(notificationId, notification)
@@ -56,6 +56,10 @@ class AuthRepository(
fun getAuthState(): AuthState {
val token = tokenStore.get() ?: return AuthState.Unauthenticated
if (tokenStore.isExpired()) {
tokenStore.clear()
return AuthState.Unauthenticated
}
return when (getScopeFromToken(token)) {
TokenScope.FULL -> AuthState.Authenticated
TokenScope.TOTP_PENDING -> AuthState.AwaitingTotp
@@ -1,6 +1,7 @@
package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.ChatClient
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SpaceDto
@@ -43,8 +44,8 @@ class ChatRepository(private val tokenStore: TokenStore) {
getChatClient()?.sendMessage(channelId, userId, text)
}
fun messageStream(channelId: Long): Flow<Message> {
fun eventStream(channelId: Long): Flow<ChatEvent> {
_lastActiveChannel = channelId
return getChatClient()?.messageStream(channelId) ?: emptyFlow()
return getChatClient()?.eventStream(channelId) ?: emptyFlow()
}
}
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.api.SettingsClient
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.TotpStatus
import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.core.error.ApiResult
@@ -25,6 +26,10 @@ class SettingsRepository(private val tokenStore: TokenStore) {
_lastToken = null
}
suspend fun createInvite(request: InviteRequest): ApiResult<String> {
return getSettingsClient()?.createInvite(request) ?: ApiResult.NetworkError("Not authenticated")
}
suspend fun getTotpQr(password: String): ApiResult<QrResponse?> {
val settingsClient = getSettingsClient() ?: return ApiResult.NetworkError("Not authenticated")
return settingsClient.getTotpQr(password)
@@ -3,7 +3,7 @@ package dev.zxq5.chatapp.android.feature.auth
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import dev.zxq5.chatapp.android.model.LoginState
import dev.zxq5.chatapp.android.api.model.LoginState
@Composable
fun AuthScreen(viewModel: AuthViewModel) {
@@ -7,7 +7,7 @@ import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.data.repository.LoginResult
import dev.zxq5.chatapp.android.data.repository.SignupResult
import dev.zxq5.chatapp.android.data.repository.AuthState
import dev.zxq5.chatapp.android.model.LoginState
import dev.zxq5.chatapp.android.api.model.LoginState
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.launch
@@ -28,7 +28,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.model.LoginState
import dev.zxq5.chatapp.android.api.model.LoginState
import dev.zxq5.chatapp.android.ui.components.TextField
@Composable
@@ -28,7 +28,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.model.LoginState
import dev.zxq5.chatapp.android.api.model.LoginState
import dev.zxq5.chatapp.android.ui.components.TextField
@Composable
@@ -1,20 +1,25 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.feature.chat
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.Channel
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.data.repository.ChatRepository
import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.Space
import dev.zxq5.chatapp.android.api.model.SpaceDto
import dev.zxq5.chatapp.android.core.service.MessageStreamService
import io.ktor.client.plugins.ResponseException
import io.ktor.http.HttpStatusCode
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
@@ -41,6 +46,8 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
private var streamJob: Job? = null
var onUnauthorized: (() -> Unit)? = null
init {
_currentUserId.value = chatRepository.getUserId()
observeChannel()
@@ -49,6 +56,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
fun loadAccessibleChannels() {
_error.value = null
_currentUserId.value = chatRepository.getUserId()
viewModelScope.launch {
runCatching {
chatRepository.getAccessibleChannels()
@@ -56,11 +64,16 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_spaces.value = data
}.onFailure { e ->
Log.e("Chat", "Failed to load spaces", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_error.value = "Failed to load channels: ${e.message}"
}
}
}
}
@OptIn(ExperimentalTime::class)
private fun observeChannel() {
viewModelScope.launch {
_channelId.collect { id ->
@@ -69,13 +82,40 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_channelError.value = null
if (id != null) {
streamJob = launch {
chatRepository.messageStream(id)
chatRepository.eventStream(id)
.catch { e ->
Log.e("Chat", "Stream error", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_channelError.value = "Connection lost: ${e.message}"
}
.collect { message ->
_messages.update { it + message }
}
.collect { event ->
when (event) {
is ChatEvent.SendMessage -> {
_messages.update { it + event.data }
}
is ChatEvent.EditMessage -> {
_messages.update { messages ->
messages.map {
if (it.id == event.data.id) event.data.message
else it
}
}
}
is ChatEvent.MessageAppendContent -> {
_messages.update { messages ->
messages.map { msg ->
if (msg.id == event.data.id) {
msg.copy(text = msg.text + event.data.content)
} else {
msg
}
}
}
}
}
}
}
}
@@ -108,15 +148,20 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
)
}.onFailure { e ->
Log.e("Chat", "Send message error", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_channelError.value = "Failed to send message"
}
}
}
}
fun clearChat() {
_messages.value = emptyList()
_channelId.value = null
_currentUserId.value = null
_spaces.value = emptyList()
_error.value = null
_channelError.value = null
streamJob?.cancel()
@@ -1,3 +1,5 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.feature.chat
import androidx.compose.foundation.BorderStroke
@@ -65,6 +67,7 @@ import dev.zxq5.chatapp.android.api.model.Message
import java.text.DateFormat
import java.util.Date
import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
@Composable
fun ChatScreen(
@@ -277,7 +280,7 @@ fun MessageScreen(channelId: Long, viewModel: ChatViewModel, onBack: () -> Unit)
modifier = Modifier.weight(1f).padding(horizontal = 16.dp),
verticalArrangement = Arrangement.spacedBy(10.dp)
) {
items(messages) { message ->
items(messages, key = { it.id }) { message ->
MessageBubble(message, currentUserId)
}
item { Spacer(Modifier.height(10.dp)) }
@@ -378,7 +381,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
horizontalAlignment = if (isMe) Alignment.End else Alignment.Start
) {
Surface(
color = if (isMe) MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f),
color = if (isMe) MaterialTheme.colorScheme.surfaceVariant else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f),
shape = RoundedCornerShape(
topStart = 14.dp,
topEnd = 14.dp,
@@ -388,14 +391,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f))
) {
Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) {
if (!isMe) {
Text(
message.display_name?.lowercase() ?: "unknown",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.primary.copy(alpha = 0.7f),
modifier = Modifier.padding(bottom = 2.dp)
)
}
Text(
text = message.text,
style = MaterialTheme.typography.bodyLarge,
@@ -403,10 +399,11 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
)
}
}
Text(
text = time,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f),
text = if (!isMe) message.display_name.lowercase() + " . " + time else time,
style = MaterialTheme.typography.labelMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp)
)
}
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.feature.settings
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.core.error.ApiResult
import dev.zxq5.chatapp.android.data.repository.SettingsRepository
@@ -27,6 +28,9 @@ class SettingsViewModel(private val settingsRepository: SettingsRepository) : Vi
private val _isSuccessState = MutableStateFlow<Map<String, Boolean>>(emptyMap())
val isSuccessState: StateFlow<Map<String, Boolean>> = _isSuccessState
private val _lastInviteCode = MutableStateFlow<String?>(null)
val lastInviteCode: StateFlow<String?> = _lastInviteCode
fun clearMessages() {
_settingsError.value = null
_totpError.value = null
@@ -40,6 +44,20 @@ class SettingsViewModel(private val settingsRepository: SettingsRepository) : Vi
}
}
fun createInvite(request: InviteRequest) {
viewModelScope.launch {
_settingsError.value = null
when (val result = settingsRepository.createInvite(request)) {
is ApiResult.Success -> {
_lastInviteCode.value = result.data
triggerSuccess("invite")
}
is ApiResult.HttpError -> _settingsError.value = result.message
is ApiResult.NetworkError -> _settingsError.value = result.message
}
}
}
fun fetchTotpStatus() {
viewModelScope.launch {
when (val result = settingsRepository.getTotpStatus()) {
@@ -23,11 +23,14 @@ import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.filled.ArrowBack
import androidx.compose.material.icons.filled.ContentCopy
import androidx.compose.material.icons.filled.KeyboardArrowDown
import androidx.compose.material.icons.filled.KeyboardArrowUp
import androidx.compose.material3.Button
import androidx.compose.material3.ButtonDefaults
import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.DatePicker
import androidx.compose.material3.DatePickerDialog
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon
@@ -37,8 +40,10 @@ import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.OutlinedTextFieldDefaults
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.rememberDatePickerState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
@@ -57,9 +62,15 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import android.util.Base64
import android.graphics.BitmapFactory
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.style.TextAlign
import dev.zxq5.chatapp.android.api.model.InviteRequest
import kotlin.time.Duration.Companion.days
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@OptIn(ExperimentalMaterial3Api::class)
@OptIn(ExperimentalMaterial3Api::class, ExperimentalTime::class)
@Composable
fun SettingsScreen(
viewModel: SettingsViewModel,
@@ -70,6 +81,7 @@ fun SettingsScreen(
val settingsError by viewModel.settingsError.collectAsState()
val isSuccessState by viewModel.isSuccessState.collectAsState()
val totpError by viewModel.totpError.collectAsState()
val lastInviteCode by viewModel.lastInviteCode.collectAsState()
LaunchedEffect(Unit) {
viewModel.clearMessages()
@@ -274,6 +286,120 @@ fun SettingsScreen(
}
}
SettingsSection(title = "invite") {
var inviteName by remember { mutableStateOf("") }
var maxUses by remember { mutableStateOf("1") }
val clipboardManager = LocalClipboardManager.current
var showDatePicker by remember { mutableStateOf(false) }
val datePickerState = rememberDatePickerState(
initialSelectedDateMillis = System.currentTimeMillis() + 7.days.inWholeMilliseconds
)
Text("create invite token", style = MaterialTheme.typography.bodyMedium, modifier = Modifier.padding(bottom = 8.dp))
OutlinedTextField(
value = inviteName,
onValueChange = { inviteName = it },
label = { Text("name") },
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp)
)
Spacer(Modifier.height(8.dp))
OutlinedTextField(
value = maxUses,
onValueChange = { if (it.all { c -> c.isDigit() }) maxUses = it },
label = { Text("max uses") },
keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp)
)
Spacer(Modifier.height(8.dp))
OutlinedTextField(
value = datePickerState.selectedDateMillis?.let { Instant.fromEpochMilliseconds(it).toString().substringBefore("T") } ?: "",
onValueChange = {},
label = { Text("expiry date") },
readOnly = true,
trailingIcon = {
IconButton(onClick = { showDatePicker = true }) {
Icon(Icons.Default.KeyboardArrowDown, contentDescription = "Select Date")
}
},
modifier = Modifier.fillMaxWidth().clickable { showDatePicker = true },
shape = RoundedCornerShape(8.dp)
)
if (showDatePicker) {
DatePickerDialog(
onDismissRequest = { showDatePicker = false },
confirmButton = {
TextButton(onClick = { showDatePicker = false }) {
Text("ok")
}
}
) {
DatePicker(state = datePickerState)
}
}
Spacer(Modifier.height(12.dp))
SuccessButton(
onClick = {
val nowMs = System.currentTimeMillis()
val expiryMs = datePickerState.selectedDateMillis ?: (nowMs + 7.days.inWholeMilliseconds)
viewModel.createInvite(
InviteRequest(
name = inviteName,
max_uses = maxUses.toIntOrNull() ?: 1,
start_date = Instant.fromEpochMilliseconds(nowMs),
expiry_date = Instant.fromEpochMilliseconds(expiryMs)
)
)
},
label = "generate invite",
isSuccess = isSuccessState["invite"] == true,
enabled = inviteName.isNotBlank(),
modifier = Modifier.fillMaxWidth()
)
if (lastInviteCode != null) {
Spacer(Modifier.height(16.dp))
Row(
modifier = Modifier
.fillMaxWidth()
.background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f), RoundedCornerShape(8.dp))
.padding(12.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween
) {
Text(
text = lastInviteCode!!,
style = MaterialTheme.typography.bodyLarge,
modifier = Modifier.weight(1f)
)
IconButton(onClick = {
clipboardManager.setText(AnnotatedString(lastInviteCode!!))
}) {
Icon(Icons.Default.ContentCopy, contentDescription = "Copy", modifier = Modifier.size(20.dp))
}
}
}
}
SettingsSection(title = "session") {
Button(
onClick = onLogout,
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp),
colors = ButtonDefaults.buttonColors(containerColor = Color.White, contentColor = Color.Black)
) {
Text("logout")
}
}
SettingsSection(title = "danger zone", color = Color.Red.copy(alpha = 0.7f)) {
var deletePassword by remember { mutableStateOf("") }
var deleteTotp by remember { mutableStateOf("") }
@@ -337,18 +463,6 @@ fun SettingsScreen(
}
}
SettingsSection(title = "session") {
Spacer(Modifier.height(16.dp))
Button(
onClick = onLogout,
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp),
colors = ButtonDefaults.buttonColors(containerColor = Color.White, contentColor = Color.Black)
) {
Text("logout")
}
}
if (settingsError != null) {
Text(settingsError!!, color = Color.Red, style = MaterialTheme.typography.bodySmall, modifier = Modifier.padding(top = 8.dp))
}
@@ -457,6 +571,7 @@ fun SuccessButton(
}
}
@OptIn(ExperimentalTime::class)
@Composable
fun TwoFactorSetup(
qrCodeBase64: String?,
@@ -511,15 +626,13 @@ fun TwoFactorSetup(
Text(error.lowercase(), color = Color.Red, style = MaterialTheme.typography.labelSmall, modifier = Modifier.padding(top = 8.dp))
}
Spacer(Modifier.height(24.dp))
Button(
onClick = { if (code.length == 6) onConfirm(code) },
Spacer(Modifier.height(16.dp))
SuccessButton(
onClick = { onConfirm(code) },
label = "verify and enable",
isSuccess = false, // Managed by parent
enabled = code.length == 6,
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp)
) {
Text("confirm code")
}
modifier = Modifier.fillMaxWidth()
)
}
}
@@ -2,12 +2,14 @@ package dev.zxq5.chatapp.android.ui.components
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.OutlinedTextFieldDefaults
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.text.input.PasswordVisualTransformation
import androidx.compose.ui.unit.dp
@@ -31,6 +33,11 @@ fun TextField(
modifier = Modifier.fillMaxWidth(),
singleLine = true,
textStyle = MaterialTheme.typography.bodyLarge,
keyboardOptions = if (isPassword) {
KeyboardOptions(keyboardType = KeyboardType.Password)
} else {
KeyboardOptions.Default
},
visualTransformation = if (isPassword) PasswordVisualTransformation() else androidx.compose.ui.text.input.VisualTransformation.None,
shape = RoundedCornerShape(8.dp),
colors = OutlinedTextFieldDefaults.colors(
@@ -40,6 +47,6 @@ fun TextField(
unfocusedBorderColor = MaterialTheme.colorScheme.outline,
focusedTextColor = MaterialTheme.colorScheme.onSurface,
unfocusedTextColor = MaterialTheme.colorScheme.onSurface
)
),
)
}
@@ -2,7 +2,7 @@ package dev.zxq5.chatapp.android.ui.theme
import androidx.compose.ui.graphics.Color
val Black = Color(0xFF0A0A0A)
val Black = Color(0xFF000000)
val DarkGrey = Color(0xFF0D0D0D)
val Grey = Color(0xFF141414)
val LightGrey = Color(0xFF1E1E1E)
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="uk.co.ben_gibson.git.link.SettingsState">
<option name="host" value="e0f86390-1091-4871-8aeb-f534fbc99cf0" />
</component>
</project>
+1 -1
View File
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourcePerFileMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
<file url="file://$PROJECT_DIR$/sql/test.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
<file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
</component>
</project>
+2 -1
View File
@@ -1,7 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" dialect="PostgreSQL" />
<file url="file://$PROJECT_DIR$/migrations/20260412200102_message_id_to_uuid.sql" dialect="PostgreSQL" />
<file url="file://$PROJECT_DIR$/sql/test.sql" dialect="PostgreSQL" />
<file url="PROJECT" dialect="PostgreSQL" />
</component>
</project>
+3 -3
View File
@@ -12,7 +12,7 @@ image = "0.25.8"
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
rand = "0.8"
redis = { version = "0.25.4", features = ["tokio-comp"] }
reqwest = { version = "0.12.23", features = ["json"] }
reqwest = { version = "0.12.23", features = ["json", "stream"] }
rocket = { version = "0.5.1", features = ["json", "secrets"] }
rocket_cors = "0.6.0"
rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] }
@@ -20,11 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145"
sha2 = "0.10.9"
sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] }
sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time", "uuid"] }
tokio = { version = "1.47.1", features = ["full"] }
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
tracing = "0.1.44"
uuid = { version = "1.18.1", features = ["v4"] }
uuid = { version = "1.18.1", features = ["serde", "v4"] }
thiserror = "1.0.69"
utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] }
clap = { version = "4.5", features = ["derive"] }
+1 -1
View File
@@ -1,7 +1,7 @@
[debug]
secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU="
address = "0.0.0.0"
port = 8000
port = 8080
[debug.databases.postgres_db]
url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev"
+2
View File
@@ -8,6 +8,8 @@ services:
- redis
env_file:
- .env
environment:
- RELEASE_MODE=1
redis:
container_name: chatapp_redis
@@ -0,0 +1,9 @@
ALTER TABLE attachments DROP CONSTRAINT attachments_message_id_fkey;
ALTER TABLE messages ALTER COLUMN id DROP DEFAULT;
ALTER TABLE messages ALTER COLUMN id TYPE uuid USING gen_random_uuid();
ALTER TABLE messages ALTER COLUMN id SET DEFAULT gen_random_uuid();
ALTER TABLE attachments ALTER COLUMN message_id TYPE uuid USING gen_random_uuid();
ALTER TABLE attachments ADD CONSTRAINT attachments_message_id_fkey
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE;
+17
View File
@@ -0,0 +1,17 @@
WITH space1 AS (
INSERT INTO spaces (name, description, owner_id)
VALUES ('general', 'Boring chat idk', 1)
RETURNING id
),
space2 AS (
INSERT INTO spaces (name, description, owner_id)
VALUES ('Gaming', 'we lose games', 1)
RETURNING id
)
INSERT INTO channels (name, description, space_id)
SELECT 'General', 'General chat', id FROM space1 UNION ALL
SELECT 'Coding', 'Coding stuff', id FROM space1 UNION ALL
SELECT 'AI', '"/ask" here pls :)', id FROM space1 UNION ALL
SELECT 'The Game', '(You lost)', id FROM space2 UNION ALL
SELECT 'Backrooms', 'Beware of Smilers', id FROM space2 UNION ALL
SELECT 'SE', 'Space/Software engineering.', id FROM space2;
+9 -16
View File
@@ -9,15 +9,7 @@ use rocket::{Shutdown, State, ___internal_EventStream as EventStream};
use sqlx::FromRow;
use tokio::select;
use tokio::sync::broadcast;
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
use crate::model::event::{ChatEvent, ChatMsg};
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message(
@@ -36,24 +28,25 @@ pub async fn event_stream(
mut shutdown: Shutdown,
channel_id: i64,
) -> ApiResult<EventStream![]> {
let messages = chat.get_messages(channel_id, 100)
let messages = chat.fetch_latest_messages_desc(channel_id, 100)
.await?; // if get message returned err, inform user.
let mut rx = chat.subscribe(channel_id).await;
let id = s.uid;
Ok(EventStream! {
for msg in messages {
yield Event::json(&msg);
for msg in messages.into_iter().rev() {
// tracing::info!("sending: {:?}", serde_json::to_string(&ChatEvent::SendMessage(msg.clone())).unwrap());
yield Event::json(&ChatEvent::SendMessage(msg));
}
loop {
select!{
_ = &mut shutdown => break, // exit early on shutdown
msg = rx.recv() => match msg {
Ok(msg) => {
tracing::info!("yielding message!");
yield Event::json(&msg)
event = rx.recv() => match event {
Ok(event) => {
// tracing::info!("yielding event: {event:?}");
yield Event::json(&event)
},
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",);
+27
View File
@@ -0,0 +1,27 @@
use rocket::serde::{Deserialize, Serialize};
use sqlx::FromRow;
use chrono::{DateTime, Utc};
use uuid::Uuid;
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub id: Uuid,
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ChatEvent {
SendMessage(ChatMsg),
/// for when a user explicitly edits a message
EditMessage { id: Uuid, msg: ChatMsg },
/// used for streaming content to a message
/// will not show up as edited
MessageAppendContent{ id: Uuid, content: String }
}
+1
View File
@@ -1,3 +1,4 @@
pub mod auth;
pub mod user;
pub mod space;
pub mod event;
+35 -23
View File
@@ -1,68 +1,80 @@
use crate::api::chat::ChatMsg;
use crate::model::event::ChatMsg;
use crate::repo::Repo;
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use uuid::Uuid;
#[derive(Clone)]
pub struct MessageRepository {
pool: PgPool
}
impl Repo for MessageRepository {
type Target = ChatMsg;
fn new(pool: PgPool) -> Self {
impl MessageRepository {
pub(crate) fn new(pool: PgPool) -> Self {
Self { pool }
}
// TODO: caching with redis
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
async fn get_by_id(&self, id: Uuid) -> Option<ChatMsg> {
sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
"SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.id = $1",
id
).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg {
id: row.id,
display_name: Some(row.nickname.unwrap_or(row.username)),
user_id: row.user_id,
text: row.content,
timestamp: row.created_at,
})
}
}
impl MessageRepository {
// TODO! caching with redis
pub async fn create_new(
&self, uid: i64, channel_id: i64,
text: &str, created_at: DateTime<Utc>
) -> Result<i64, sqlx::Error> {
&self, msg: ChatMsg, channel_id: i64
) -> Result<(), sqlx::Error> {
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at)
VALUES ($1, $2, $3, $4) RETURNING id",
"INSERT INTO messages (id, channel_id, user_id, content, created_at)
VALUES ($1, $2, $3, $4, $5)",
msg.id,
channel_id,
uid,
msg.user_id,
msg.text,
msg.timestamp
).execute(&self.pool).await.map_err(|_| sqlx::Error::RowNotFound)?;
Ok(())
}
pub async fn update_text(&self, id: Uuid, text: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE messages SET content = $1 WHERE id = $2",
text,
created_at
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
id
).execute(&self.pool).await?;
Ok(())
}
/// TODO: caching with redis
pub async fn get_by_channel(&self, channel_id: i64, limit: usize)
-> Result<Vec<ChatMsg>, sqlx::Error> {
pub async fn get_latest_by_channel_desc(
&self, channel_id: i64, limit: usize, page: usize
) -> Result<Vec<ChatMsg>, sqlx::Error> {
sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
"SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.channel_id = $1
ORDER BY m.created_at DESC LIMIT $2",
ORDER BY m.created_at DESC
LIMIT $2 OFFSET $3",
channel_id,
limit as i64
limit as i64,
page as i64
).fetch_all(&self.pool).await.map(|messages| {
messages.into_iter().rev().map(|msg| {
messages.into_iter().map(|msg| {
ChatMsg {
id: msg.id,
display_name: Some(msg.nickname.unwrap_or(msg.username)),
user_id: msg.user_id,
text: msg.content,
+104 -42
View File
@@ -1,12 +1,15 @@
use crate::api::chat::ChatMsg;
use crate::model::event::{ChatEvent, ChatMsg};
use crate::error::{ApiResult, AppError};
use crate::repo::message_repo::MessageRepository;
use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo};
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::sync::mpsc::channel;
use std::time::Instant;
use tokio::sync::broadcast::Sender;
use tokio::sync::{broadcast, Mutex};
use uuid::Uuid;
use crate::model::space::SpaceDto;
use crate::svc::llm_service::LlmService;
@@ -15,15 +18,13 @@ use crate::svc::llm_service::LlmService;
#[derive(Clone)]
pub struct ChatService {
users: Arc<dyn UserRepo>,
channels: Arc<dyn ChannelRepo>,
channel_repo: Arc<dyn ChannelRepo>,
spaces: Arc<dyn SpaceRepo>,
messages: MessageRepository,
llm: LlmService,
buffer_size: usize,
senders: Arc<Mutex<HashMap<i64, Sender<ChatMsg>>>>,
channels: Arc<Mutex<HashMap<i64, ChannelState>>>,
}
impl ChatService {
@@ -33,13 +34,13 @@ impl ChatService {
channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>,
) -> Self {
Self {
channels,
channel_repo: channels,
spaces,
llm,
users,
messages,
buffer_size,
senders: Arc::new(Mutex::new(std::collections::HashMap::new())),
channels: Arc::new(Mutex::new(std::collections::HashMap::new())),
}
}
@@ -50,7 +51,7 @@ impl ChatService {
let mut result = Vec::new();
for space in spaces {
let channels = self.channels.get_by_space_id(space.id).await?;
let channels = self.channel_repo.get_by_space_id(space.id).await?;
result.push(SpaceDto {
channels,
id: space.id,
@@ -65,8 +66,10 @@ impl ChatService {
Ok(result)
}
pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
let messages = self.messages.get_by_channel(channel_id, limit).await?;
pub async fn fetch_latest_messages_desc(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
const PAGE: usize = 0;
let messages = self.messages.get_latest_by_channel_desc(channel_id, limit, PAGE).await?;
Ok(messages)
}
@@ -113,6 +116,7 @@ impl ChatService {
.ok_or(AppError::NotFound)?;
let message = ChatMsg {
id: Uuid::new_v4(),
display_name: Some(user
.nickname.clone()
.unwrap_or_else(|| user.username.clone())),
@@ -122,66 +126,124 @@ impl ChatService {
};
self.publish(channel_id, message.clone()).await;
let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?;
self.messages.create_new(message.clone(), channel_id).await?;
// TODO: caching w redis at repository layer
let svc_instance = self.clone();
let Some(text) = text.strip_prefix("/ask ") else {
return Ok(())
};
if !svc_instance.llm.enabled() {
return Ok(())
if !message.text.starts_with("/ask ") {
return Ok(());
}
let svc_instance = self.clone();
if !svc_instance.llm.enabled() {
return Ok(());
}
let context = self.get_history(channel_id, 25).await?;
tokio::spawn(async move {
let sender = match svc_instance.channels.lock().await.get(&channel_id) {
Some(s) => s.get_sender(),
None => return,
};
let response = svc_instance.llm
.query(&message)
.query(&message, &context, sender)
.await;
if let Ok(reply) = response {
let Ok(reply) = response else {
tracing::warn!("Error contacting LLM: {:?}", response);
return;
};
tracing::info!("LLM reply: {}", reply.text);
svc_instance.publish(channel_id, reply.clone()).await;
// TODO: cache response (or do with redis!)
if let Err(e) = svc_instance.messages
.create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await {
if let Err(e) = svc_instance.messages.create_new(reply, channel_id).await {
tracing::error!("Failed to persist LLM reply: {}", e);
}
tracing::info!("LLM reply persisted");
} else {
tracing::warn!("Error contacting LLM: {:?}", response);
}
tracing::debug!("Full LLM reply persisted");
});
Ok(())
}
async fn get_history(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
let mut map = self.channels.lock().await;
if let Some(channel) = map.get(&channel_id) && !channel.cache.is_empty() {
Ok(channel.history().clone().into_iter().take(limit).collect())
} else {
let messages: Vec<_> = self.messages.get_latest_by_channel_desc(channel_id, limit, 0).await?.into_iter().rev().collect();
map.insert(channel_id,
ChannelState::new(self.buffer_size, Some(messages.clone().into()))
);
Ok(messages)
}
}
/// Subscribe to the specified channel.
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatMsg> {
let mut map = self.senders.lock().await;
let sender = map
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatEvent> {
let mut map = self.channels.lock().await;
let channel = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
sender.subscribe()
.or_insert_with(|| ChannelState::new(self.buffer_size, None));
channel.subscribe()
}
// Private helper methods
/// Publish a message to the specified channel.
async fn publish(&self, channel_id: i64, msg: ChatMsg) {
let mut map = self.senders.lock().await;
let sender = map
let mut map = self.channels.lock().await;
let channel = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
let _ = sender.send(msg);
.or_insert_with(|| ChannelState::new(self.buffer_size, None));
channel.send(msg);
}
}
#[derive(Clone)]
pub struct ChannelState {
sender: Sender<ChatEvent>,
cache: VecDeque<ChatMsg>,
last_updated: Instant,
}
impl ChannelState {
const MAX_HISTORY_SIZE: usize = 100;
#[must_use]
pub fn new(buffer_size: usize, history: Option<VecDeque<ChatMsg>>) -> Self {
Self {
sender: broadcast::channel(buffer_size).0,
cache: history.unwrap_or_default(),
last_updated: Instant::now(),
}
}
pub fn history(&self) -> &VecDeque<ChatMsg> {
&self.cache
}
pub fn get_sender(&self) -> Sender<ChatEvent> {
self.sender.clone()
}
#[must_use]
pub fn send(&mut self, msg: ChatMsg) {
while self.cache.len() >= Self::MAX_HISTORY_SIZE {
self.cache.pop_front();
}
self.cache.push_back(msg.clone());
if self.sender.send(ChatEvent::SendMessage(msg)).is_err() {
tracing::warn!("Sent message to channel with no subscribers");
}
}
#[must_use]
pub fn subscribe(&self) -> broadcast::Receiver<ChatEvent> {
self.sender.subscribe()
}
}
+119 -41
View File
@@ -1,3 +1,13 @@
use std::env;
use std::sync::LazyLock;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast::Sender;
use uuid::Uuid;
use crate::model::event::{ChatEvent, ChatMsg};
use crate::error::{ApiResult, AppError};
#[derive(Clone)]
pub struct LlmService;
@@ -15,71 +25,139 @@ impl LlmService {
LMSTUDIO_URL.is_some()
}
pub async fn query(&self, message: &ChatMsg) -> ApiResult<ChatMsg> {
pub async fn query(&self, message: &ChatMsg, context: &[ChatMsg], sender: Sender<ChatEvent>) -> ApiResult<ChatMsg> {
let Some(url) = LMSTUDIO_URL.clone() else {
return Err(AppError::internal("AI not enabled!"))
};
let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into());
let client = reqwest::Client::new();
// Build the request body
let payload = LlmRequest {
model, // whatever model you run locally
messages: vec![Message {
role: "user".into(),
content: message.text.clone(),
}],
let reply_id = Uuid::new_v4();
let timestamp = chrono::Utc::now();
let mut reply = ChatMsg {
id: reply_id,
display_name: Some("llm".into()),
user_id: 0,
text: String::new(),
timestamp,
};
let _ = sender.send(ChatEvent::SendMessage(reply.clone()));
// POST to lmstudio (default 127.0.0.1:1234)
let resp = client
let mut messages: Vec<Message> = Vec::new();
let system_prompt = format!(
"You are a helpful assistant in a group chat. \
You are talking to '{}'. \
Keep responses concise and conversational.",
message.display_name.as_deref().unwrap_or("unknown"),
);
messages.push(Message { role: "system".into(), content: system_prompt });
for msg in context {
let role = if msg.user_id == 0 {
"assistant" // your LLM user_id convention
} else {
"user"
};
messages.push(Message {
role: role.into(),
content: format!(
"{}: {}",
msg.display_name.as_deref().unwrap_or("unknown"),
msg.text
),
});
}
messages.push(Message {
role: "user".into(),
content: format!(
"{}: {}",
message.display_name.as_deref().unwrap_or("unknown"),
message.text.trim_start_matches("/ask ") // strip the command prefix
),
});
let Ok(resp) = reqwest::Client::new()
.post(url)
.json(&payload)
.json(&LlmRequest {
think: false,
model: LMSTUDIO_MODEL
.clone()
.unwrap_or_else(|| "gpt-oss-20b".into()),
messages,
stream: true,
})
.send()
.await
.map_err(|_| AppError::internal("Failed to make request to LLM server"))?;
else {
tracing::warn!("Failed to reach LLM");
let _ = sender.send(ChatEvent::MessageAppendContent {
id: reply_id,
content: String::from("I'm not available right now. Please try again later.")
});
return Err(AppError::internal("Failed to reach LLM"));
};
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)]
struct LlmResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize)]
struct Choice {
message: Message,
let mut full_text = String::new();
let mut buffer = String::new();
let mut stream = resp.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|_| AppError::internal("Stream error"))?;
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find('\n') {
let line = buffer[..pos].trim().to_string();
buffer = buffer[pos + 1..].to_string();
if line == "data: [DONE]" {
break;
}
let llm_resp: LlmResponse = resp
.json()
.await
.map_err(|_| AppError::internal("Failed to parse LLM response"))?;
if let Some(json) = line.strip_prefix("data: ")
&& let Ok(parsed) = serde_json::from_str::<StreamingResponse>(json)
&& let Some(content) = parsed.choices[0].delta.content.as_ref() {
full_text.push_str(content);
let _ = sender.send(ChatEvent::MessageAppendContent {
id: reply_id,
content: content.clone(),
});
}
}
}
Ok(ChatMsg {
display_name: Some(String::from("llm")),
user_id: 0,
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now(),
})
reply.text = full_text;
Ok(reply)
}
}
use std::env;
use std::sync::LazyLock;
// src/llm.rs
use serde::{Deserialize, Serialize};
use crate::api::chat::ChatMsg;
use crate::error::{ApiResult, AppError};
use crate::svc::chat_svc::ChatService;
#[derive(Serialize)]
struct LlmRequest {
model: String,
messages: Vec<Message>,
stream: bool,
think: bool,
}
#[derive(Deserialize)]
struct StreamingResponse {
choices: Vec<StreamingChoice>,
}
#[derive(Deserialize)]
struct StreamingChoice {
delta: Delta,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct Delta {
#[serde(default)]
content: Option<String>,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String, // "user" or "assistant"
+1
View File
@@ -4,3 +4,4 @@ pub mod settings_svc;
pub mod user_svc;
pub mod access_token_svc;
pub mod llm_service;
pub mod relationship_svc;
+6
View File
@@ -0,0 +1,6 @@
pub struct RelationshipService {}
impl RelationshipService {
pub fn new() -> Self { Self {} }
}
+228
View File
@@ -0,0 +1,228 @@
import argparse
import json
import threading
import time
from dataclasses import dataclass
from getpass import getpass
from typing import List
import requests
BASE_URL = "http://localhost:8000"
@dataclass
class AuthResult:
token: str
def signup(session: requests.Session, email: str, username: str, password: str, access_token: str) -> None:
url = f"{BASE_URL}/api/signup"
payload = {
"email": email,
"username": username,
"password": password,
"access_token": access_token,
}
resp = session.post(url, json=payload, timeout=10)
resp.raise_for_status()
def login(session: requests.Session, username: str, password: str) -> AuthResult:
url = f"{BASE_URL}/api/login"
payload = {
"username": username,
"password": password,
}
resp = session.post(url, json=payload, timeout=10)
resp.raise_for_status()
data = resp.json()
token = data.get("token")
if not token:
raise RuntimeError(f"Login response did not contain token: {data}")
return AuthResult(token=token)
def post_message(session: requests.Session, channel_id: int, token: str, text: str, display_name: str, user_id: int) -> None:
url = f"{BASE_URL}/api/chat/{channel_id}"
payload = {
"display_name": display_name,
"user_id": user_id,
"text": text,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
headers = {"Authorization": f"Bearer {token}"}
resp = session.post(url, json=payload, headers=headers, timeout=10)
resp.raise_for_status()
def read_sse_messages(
session: requests.Session,
channel_id: int,
token: str,
expected_count: int,
timeout_s: int,
capture_live_messages: threading.Event,
) -> List[dict]:
url = f"{BASE_URL}/api/events/{channel_id}"
headers = {
"Authorization": f"Bearer {token}",
"Accept": "text/event-stream",
}
received: List[dict] = []
deadline = time.monotonic() + timeout_s
try:
with session.get(url, headers=headers, stream=True, timeout=(5, timeout_s)) as resp:
resp.raise_for_status()
event_data_lines: List[str] = []
for raw_line in resp.iter_lines(decode_unicode=True):
if time.monotonic() > deadline:
break
if raw_line is None:
continue
line = raw_line.strip()
if not line:
if event_data_lines:
joined = "\n".join(event_data_lines)
event_data_lines.clear()
try:
obj = json.loads(joined)
except json.JSONDecodeError:
continue
if capture_live_messages.is_set():
received.append(obj)
if expected_count > 0 and len(received) >= expected_count:
break
else:
print(f"Discarding message: {obj}")
continue
if line.startswith("data:"):
event_data_lines.append(line[len("data:"):].strip())
except requests.exceptions.Timeout:
print("Timeout while reading SSE.")
except requests.exceptions.RequestException as exc:
print(f"Error reading SSE: {exc}")
return received
def prompt_nonempty(label: str, secret: bool = False) -> str:
while True:
value = getpass(label) if secret else input(label)
value = value.strip()
if value:
return value
print("Please enter a value.")
def main() -> int:
parser = argparse.ArgumentParser(description="Chat integration test against localhost:8000")
parser.add_argument("--existing-account", action="store_true",
help="Skip signup and only log in with an existing account")
parser.add_argument("--email", default=None)
parser.add_argument("--username", default=None)
parser.add_argument("--password", default=None)
parser.add_argument("--access-token", default=None,
help="Required only for signup mode")
parser.add_argument("--channel-id", type=int, default=1)
parser.add_argument("--message-count", type=int, default=5)
parser.add_argument("--timeout", type=int, default=15)
args = parser.parse_args()
session = requests.Session()
if args.existing_account:
username = args.username or prompt_nonempty("Username: ")
password = args.password or prompt_nonempty("Password: ", secret=True)
else:
email = args.email or prompt_nonempty("Email: ")
username = args.username or prompt_nonempty("Username: ")
password = args.password or prompt_nonempty("Password: ", secret=True)
access_token = args.access_token or prompt_nonempty("Access token: ")
print("[1/5] Signing up...")
try:
signup(session, email, username, password, access_token)
print(" signup ok")
except requests.HTTPError as e:
print(f" signup returned HTTP error: {e}")
print(" continuing to login...")
print("[2/5] Logging in...")
auth = login(session, username, password)
print(" login ok")
print(f" token: {auth.token[:12]}...")
print("[3/5] Opening event stream...")
received_messages: List[dict] = []
capture_live_messages = threading.Event()
stream_done = threading.Event()
def stream_reader() -> None:
nonlocal received_messages
try:
received_messages = read_sse_messages(
session=session,
channel_id=args.channel_id,
token=auth.token,
expected_count=args.message_count,
timeout_s=args.timeout,
capture_live_messages=capture_live_messages,
)
finally:
stream_done.set()
t = threading.Thread(target=stream_reader, daemon=True)
t.start()
# Give the server time to flush backlog on this same stream connection.
time.sleep(1.0)
print("[4/5] Starting to capture live messages and sending messages...")
capture_live_messages.set()
sent_texts = [f"Message {i}" for i in range(args.message_count)]
for i, text in enumerate(sent_texts):
post_message(
session=session,
channel_id=args.channel_id,
token=auth.token,
text=text,
display_name=username,
user_id=1,
)
print(f" sent {i + 1}/{args.message_count}: {text}")
time.sleep(0.1)
stream_done.wait(timeout=args.timeout)
t.join(timeout=1)
print("\nReceived messages:")
for i, msg in enumerate(received_messages, start=1):
print(f" {i}. {msg}")
received_texts = [m.get("text") for m in received_messages if isinstance(m, dict)]
for text in sent_texts:
if text not in received_texts:
print(f"\nFAIL: missing message: {text}")
return 1
print("\nPASS: login and message delivery test succeeded.")
return 0
if __name__ == "__main__":
raise SystemExit(main())