diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml index ef655b7..656b73a 100644 --- a/android/app/src/main/AndroidManifest.xml +++ b/android/app/src/main/AndroidManifest.xml @@ -1,6 +1,6 @@ + xmlns:tools="http://tools.android.com/tools"> diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/MainActivity.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/MainActivity.kt index 89d75cb..4bcf910 100644 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/MainActivity.kt +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/MainActivity.kt @@ -1,9 +1,13 @@ package dev.zxq5.chatapp.android +import android.Manifest +import android.os.Build import android.os.Bundle import androidx.activity.ComponentActivity +import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.setContent import androidx.activity.enableEdgeToEdge +import androidx.activity.result.contract.ActivityResultContracts import androidx.compose.foundation.border import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.fillMaxSize @@ -62,9 +66,22 @@ class MainActivity : ComponentActivity() { val currentScreen by chatViewModel.currentScreen.collectAsState() val selectedChannelId by chatViewModel.channelId.collectAsState() + // Permission request launcher + val launcher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.RequestPermission(), + onResult = { isGranted -> + if (isGranted && authState == AuthState.Authenticated) { + MessageStreamService.start(this@MainActivity) + } + } + ) + LaunchedEffect(authState) { when (authState) { AuthState.Authenticated -> { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + launcher.launch(Manifest.permission.POST_NOTIFICATIONS) + } MessageStreamService.start(this@MainActivity) chatViewModel.loadAccessibleChannels() } diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/api/ChatClient.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/api/ChatClient.kt index 6319963..71a2d3c 100644 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/api/ChatClient.kt +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/api/ChatClient.kt @@ -1,7 +1,9 @@ +@file:OptIn(ExperimentalUuidApi::class) + package dev.zxq5.chatapp.android.api import dev.zxq5.chatapp.android.BuildConfig.BASE_URL -import dev.zxq5.chatapp.android.api.model.Message +import dev.zxq5.chatapp.android.api.model.ChatEvent import dev.zxq5.chatapp.android.api.model.SendMessage import dev.zxq5.chatapp.android.api.model.SpaceDto import io.ktor.client.HttpClient @@ -25,6 +27,8 @@ import kotlinx.coroutines.flow.flow import kotlinx.serialization.json.Json import kotlin.time.Clock import kotlin.time.ExperimentalTime +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid class ChatClient(private val token: String) { @@ -45,18 +49,18 @@ class ChatClient(private val token: String) { suspend fun sendMessage(channelId: Long, userId: Int, text: String) { http.post("${BASE_URL}/api/chat/$channelId") { contentType(ContentType.Application.Json) - setBody(SendMessage(user_id = userId, text = text, timestamp = Clock.System.now())) + setBody(SendMessage(id = Uuid.random(), user_id = userId, text = text, timestamp = Clock.System.now())) } } - fun messageStream(channelId: Long): Flow = flow { + fun eventStream(channelId: Long): Flow = flow { http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response -> val channel = response.bodyAsChannel() while (!channel.isClosedForRead) { val line = channel.readLine() ?: break if (line.startsWith("data:")) { val json = line.removePrefix("data:").trim() - runCatching { Json.decodeFromString(json) } + runCatching { Json.decodeFromString(json) } .onSuccess { emit(it) } } } diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/Event.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/Event.kt new file mode 100644 index 0000000..16cb061 --- /dev/null +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/Event.kt @@ -0,0 +1,60 @@ +@file:OptIn(ExperimentalUuidApi::class, ExperimentalTime::class) + +package dev.zxq5.chatapp.android.api.model + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.json.JsonClassDiscriminator +import kotlin.time.ExperimentalTime +import kotlin.time.Instant +import kotlinx.serialization.Serializable +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +@OptIn(ExperimentalSerializationApi::class) +@Serializable +@JsonClassDiscriminator("type") +sealed class ChatEvent { + + @Serializable + @SerialName("SendMessage") + data class SendMessage( + val data: Message + ) : ChatEvent() + + @Serializable + @SerialName("EditMessage") + data class EditMessage( + val data: EditMessageContent + ) : ChatEvent() + + @Serializable + @SerialName("MessageAppendContent") + data class MessageAppendContent( + val data: AppendContent + ) : ChatEvent() +} + +// tuple variants like (i64, ChatMsg) and (i64, String) +// need wrapper classes since kotlinx can't deserialise +// bare JSON arrays into data classes directly +@Serializable +data class EditMessageContent( + val id: Uuid, + val message: Message +) + +@Serializable +data class AppendContent ( + val id: Uuid, + val content: String +) + +@Serializable +data class Message ( + val id: Uuid, + val user_id: Int, + val display_name: String, + val text: String, + val timestamp: Instant +) \ No newline at end of file diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/Message.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/Message.kt deleted file mode 100644 index 8dada22..0000000 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/Message.kt +++ /dev/null @@ -1,13 +0,0 @@ -package dev.zxq5.chatapp.android.api.model - -import kotlinx.serialization.Serializable -import kotlin.time.ExperimentalTime -import kotlin.time.Instant - -@Serializable -data class Message @OptIn(ExperimentalTime::class) constructor( - val user_id: Int, - val display_name: String, - val text: String, - val timestamp: Instant -) \ No newline at end of file diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/SendMessage.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/SendMessage.kt index efa60ea..73255af 100644 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/SendMessage.kt +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/api/model/SendMessage.kt @@ -3,9 +3,12 @@ package dev.zxq5.chatapp.android.api.model import kotlinx.serialization.Serializable import kotlin.time.ExperimentalTime import kotlin.time.Instant +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid @Serializable -data class SendMessage @OptIn(ExperimentalTime::class) constructor( +data class SendMessage @OptIn(ExperimentalTime::class, ExperimentalUuidApi::class) constructor( + val id: Uuid, val user_id: Int, val text: String, val timestamp: Instant diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/core/service/MessageStreamService.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/core/service/MessageStreamService.kt index 829057e..f227a7d 100644 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/core/service/MessageStreamService.kt +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/core/service/MessageStreamService.kt @@ -6,6 +6,7 @@ import android.content.Intent import android.os.IBinder import android.util.Log import dev.zxq5.chatapp.android.ChatApplication +import dev.zxq5.chatapp.android.api.model.ChatEvent import dev.zxq5.chatapp.android.data.repository.ChatRepository import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers @@ -37,8 +38,6 @@ class MessageStreamService : Service() { fun start(context: Context) { val intent = Intent(context, MessageStreamService::class.java) - // Use startService to avoid the requirement for a persistent notification. - // This also prevents ForegroundServiceDidNotStartInTimeException. context.startService(intent) } @@ -64,17 +63,20 @@ class MessageStreamService : Service() { if (channelId == null) return currentStreamJob = serviceScope.launch { - chatRepository.messageStream(channelId) + chatRepository.eventStream(channelId) .catch { e -> Log.e("Service", "Stream error", e) } - .collect { message -> + .collect { event -> // Only show notification when an event (new message) is received // and the app is not in the foreground on this channel. if (!ChatApplication.AppState.isInForeground || activeChannelId != channelId) { - notificationService.showMessageNotification( - conversationId = channelId.toString(), - senderName = message.display_name, - messagePreview = message.text - ) + when (event) { + is ChatEvent.SendMessage -> notificationService.showMessageNotification( + conversationId = channelId.toString(), + senderName = event.data.display_name, + messagePreview = event.data.text + ) + else -> {} + } } } } diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/data/repository/ChatRepository.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/data/repository/ChatRepository.kt index 23de5bb..369eb91 100644 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/data/repository/ChatRepository.kt +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/data/repository/ChatRepository.kt @@ -1,6 +1,7 @@ package dev.zxq5.chatapp.android.data.repository import dev.zxq5.chatapp.android.api.ChatClient +import dev.zxq5.chatapp.android.api.model.ChatEvent import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.SpaceDto @@ -43,8 +44,8 @@ class ChatRepository(private val tokenStore: TokenStore) { getChatClient()?.sendMessage(channelId, userId, text) } - fun messageStream(channelId: Long): Flow { + fun eventStream(channelId: Long): Flow { _lastActiveChannel = channelId - return getChatClient()?.messageStream(channelId) ?: emptyFlow() + return getChatClient()?.eventStream(channelId) ?: emptyFlow() } } diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/feature/chat/ChatViewModel.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/feature/chat/ChatViewModel.kt index 779f94a..e454306 100644 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/feature/chat/ChatViewModel.kt +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/feature/chat/ChatViewModel.kt @@ -1,12 +1,13 @@ +@file:OptIn(ExperimentalUuidApi::class) + package dev.zxq5.chatapp.android.feature.chat import android.util.Log import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope -import dev.zxq5.chatapp.android.api.model.Channel +import dev.zxq5.chatapp.android.api.model.ChatEvent import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.api.model.Message -import dev.zxq5.chatapp.android.api.model.Space import dev.zxq5.chatapp.android.api.model.SpaceDto import dev.zxq5.chatapp.android.core.service.MessageStreamService import io.ktor.client.plugins.ResponseException @@ -17,6 +18,8 @@ import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch +import kotlin.time.ExperimentalTime +import kotlin.uuid.ExperimentalUuidApi class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() { @@ -70,6 +73,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() { } } + @OptIn(ExperimentalTime::class) private fun observeChannel() { viewModelScope.launch { _channelId.collect { id -> @@ -78,7 +82,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() { _channelError.value = null if (id != null) { streamJob = launch { - chatRepository.messageStream(id) + chatRepository.eventStream(id) .catch { e -> Log.e("Chat", "Stream error", e) if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) { @@ -87,8 +91,31 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() { _channelError.value = "Connection lost: ${e.message}" } } - .collect { message -> - _messages.update { it + message } + .collect { event -> + when (event) { + is ChatEvent.SendMessage -> { + _messages.update { it + event.data } + } + is ChatEvent.EditMessage -> { + _messages.update { messages -> + messages.map { + if (it.id == event.data.id) event.data.message + else it + } + } + } + is ChatEvent.MessageAppendContent -> { + _messages.update { messages -> + messages.map { msg -> + if (msg.id == event.data.id) { + msg.copy(text = msg.text + event.data.content) + } else { + msg + } + } + } + } + } } } } diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/feature/chat/chat.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/feature/chat/chat.kt index df66fde..b412748 100644 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/feature/chat/chat.kt +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/feature/chat/chat.kt @@ -1,3 +1,5 @@ +@file:OptIn(ExperimentalUuidApi::class) + package dev.zxq5.chatapp.android.feature.chat import androidx.compose.foundation.BorderStroke @@ -65,6 +67,7 @@ import dev.zxq5.chatapp.android.api.model.Message import java.text.DateFormat import java.util.Date import kotlin.time.ExperimentalTime +import kotlin.uuid.ExperimentalUuidApi @Composable fun ChatScreen( @@ -277,7 +280,7 @@ fun MessageScreen(channelId: Long, viewModel: ChatViewModel, onBack: () -> Unit) modifier = Modifier.weight(1f).padding(horizontal = 16.dp), verticalArrangement = Arrangement.spacedBy(10.dp) ) { - items(messages) { message -> + items(messages, key = { it.id }) { message -> MessageBubble(message, currentUserId) } item { Spacer(Modifier.height(10.dp)) } @@ -378,7 +381,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) { horizontalAlignment = if (isMe) Alignment.End else Alignment.Start ) { Surface( - color = if (isMe) MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f), + color = if (isMe) MaterialTheme.colorScheme.surfaceVariant else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f), shape = RoundedCornerShape( topStart = 14.dp, topEnd = 14.dp, @@ -388,14 +391,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) { border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f)) ) { Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) { - if (!isMe) { - Text( - message.display_name?.lowercase() ?: "unknown", - style = MaterialTheme.typography.labelSmall, - color = MaterialTheme.colorScheme.primary.copy(alpha = 0.7f), - modifier = Modifier.padding(bottom = 2.dp) - ) - } + Text( text = message.text, style = MaterialTheme.typography.bodyLarge, @@ -403,10 +399,11 @@ fun MessageBubble(message: Message, currentUserId: Int?) { ) } } + Text( - text = time, - style = MaterialTheme.typography.labelSmall, - color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f), + text = if (!isMe) message.display_name.lowercase() + " . " + time else time, + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f), modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp) ) } diff --git a/android/app/src/main/java/dev/zxq5/chatapp/android/ui/theme/Color.kt b/android/app/src/main/java/dev/zxq5/chatapp/android/ui/theme/Color.kt index fec32eb..2adf15f 100644 --- a/android/app/src/main/java/dev/zxq5/chatapp/android/ui/theme/Color.kt +++ b/android/app/src/main/java/dev/zxq5/chatapp/android/ui/theme/Color.kt @@ -2,7 +2,7 @@ package dev.zxq5.chatapp.android.ui.theme import androidx.compose.ui.graphics.Color -val Black = Color(0xFF0A0A0A) +val Black = Color(0xFF000000) val DarkGrey = Color(0xFF0D0D0D) val Grey = Color(0xFF141414) val LightGrey = Color(0xFF1E1E1E) diff --git a/backend/.idea/GitLink.xml b/backend/.idea/GitLink.xml new file mode 100644 index 0000000..5143819 --- /dev/null +++ b/backend/.idea/GitLink.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/backend/.idea/data_source_mapping.xml b/backend/.idea/data_source_mapping.xml index c63b183..1f991b4 100644 --- a/backend/.idea/data_source_mapping.xml +++ b/backend/.idea/data_source_mapping.xml @@ -1,7 +1,7 @@ - + \ No newline at end of file diff --git a/backend/.idea/sqldialects.xml b/backend/.idea/sqldialects.xml index a14b701..dfa2cfa 100644 --- a/backend/.idea/sqldialects.xml +++ b/backend/.idea/sqldialects.xml @@ -1,7 +1,8 @@ - + + \ No newline at end of file diff --git a/backend/Cargo.toml b/backend/Cargo.toml index c2318ee..6d1326d 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -12,7 +12,7 @@ image = "0.25.8" jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] } rand = "0.8" redis = { version = "0.25.4", features = ["tokio-comp"] } -reqwest = { version = "0.12.23", features = ["json"] } +reqwest = { version = "0.12.23", features = ["json", "stream"] } rocket = { version = "0.5.1", features = ["json", "secrets"] } rocket_cors = "0.6.0" rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] } @@ -20,11 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" sha2 = "0.10.9" -sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] } +sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time", "uuid"] } tokio = { version = "1.47.1", features = ["full"] } totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] } tracing = "0.1.44" -uuid = { version = "1.18.1", features = ["v4"] } +uuid = { version = "1.18.1", features = ["serde", "v4"] } thiserror = "1.0.69" utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] } clap = { version = "4.5", features = ["derive"] } diff --git a/backend/Rocket.toml b/backend/Rocket.toml index a6d4565..920798e 100644 --- a/backend/Rocket.toml +++ b/backend/Rocket.toml @@ -1,7 +1,7 @@ [debug] secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU=" address = "0.0.0.0" -port = 8000 +port = 8080 [debug.databases.postgres_db] url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev" diff --git a/backend/migrations/20260412200102_message_id_to_uuid.sql b/backend/migrations/20260412200102_message_id_to_uuid.sql new file mode 100644 index 0000000..e0fa9d5 --- /dev/null +++ b/backend/migrations/20260412200102_message_id_to_uuid.sql @@ -0,0 +1,9 @@ +ALTER TABLE attachments DROP CONSTRAINT attachments_message_id_fkey; + +ALTER TABLE messages ALTER COLUMN id DROP DEFAULT; +ALTER TABLE messages ALTER COLUMN id TYPE uuid USING gen_random_uuid(); +ALTER TABLE messages ALTER COLUMN id SET DEFAULT gen_random_uuid(); + +ALTER TABLE attachments ALTER COLUMN message_id TYPE uuid USING gen_random_uuid(); +ALTER TABLE attachments ADD CONSTRAINT attachments_message_id_fkey + FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE; diff --git a/backend/sql/test.sql b/backend/sql/test.sql index e69de29..27e007f 100644 --- a/backend/sql/test.sql +++ b/backend/sql/test.sql @@ -0,0 +1,17 @@ +WITH space1 AS ( + INSERT INTO spaces (name, description, owner_id) + VALUES ('general', 'Boring chat idk', 1) + RETURNING id +), + space2 AS ( + INSERT INTO spaces (name, description, owner_id) + VALUES ('Gaming', 'we lose games', 1) + RETURNING id + ) +INSERT INTO channels (name, description, space_id) +SELECT 'General', 'General chat', id FROM space1 UNION ALL +SELECT 'Coding', 'Coding stuff', id FROM space1 UNION ALL +SELECT 'AI', '"/ask" here pls :)', id FROM space1 UNION ALL +SELECT 'The Game', '(You lost)', id FROM space2 UNION ALL +SELECT 'Backrooms', 'Beware of Smilers', id FROM space2 UNION ALL +SELECT 'SE', 'Space/Software engineering.', id FROM space2; \ No newline at end of file diff --git a/backend/src/api/chat.rs b/backend/src/api/chat.rs index 63cda73..9e37ef4 100644 --- a/backend/src/api/chat.rs +++ b/backend/src/api/chat.rs @@ -9,15 +9,7 @@ use rocket::{Shutdown, State, ___internal_EventStream as EventStream}; use sqlx::FromRow; use tokio::select; use tokio::sync::broadcast; - -/// ---------- Rocket routes ---------- -#[derive(Debug, Serialize, Deserialize, Clone, FromRow)] -pub struct ChatMsg { - pub display_name: Option, - pub user_id: i64, - pub text: String, - pub timestamp: DateTime, -} +use crate::model::event::{ChatEvent, ChatMsg}; #[post("/chat/", format = "json", data = "")] pub async fn post_message( @@ -36,24 +28,25 @@ pub async fn event_stream( mut shutdown: Shutdown, channel_id: i64, ) -> ApiResult { - let messages = chat.get_messages(channel_id, 100) + let messages = chat.fetch_latest_messages_desc(channel_id, 100) .await?; // if get message returned err, inform user. let mut rx = chat.subscribe(channel_id).await; let id = s.uid; Ok(EventStream! { - for msg in messages { - yield Event::json(&msg); + for msg in messages.into_iter().rev() { + // tracing::info!("sending: {:?}", serde_json::to_string(&ChatEvent::SendMessage(msg.clone())).unwrap()); + yield Event::json(&ChatEvent::SendMessage(msg)); } loop { select!{ _ = &mut shutdown => break, // exit early on shutdown - msg = rx.recv() => match msg { - Ok(msg) => { - tracing::info!("yielding message!"); - yield Event::json(&msg) + event = rx.recv() => match event { + Ok(event) => { + // tracing::info!("yielding event: {event:?}"); + yield Event::json(&event) }, Err(broadcast::error::RecvError::Lagged(n)) => { tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",); diff --git a/backend/src/model/event.rs b/backend/src/model/event.rs new file mode 100644 index 0000000..23c699c --- /dev/null +++ b/backend/src/model/event.rs @@ -0,0 +1,27 @@ +use rocket::serde::{Deserialize, Serialize}; +use sqlx::FromRow; +use chrono::{DateTime, Utc}; +use uuid::Uuid; + +/// ---------- Rocket routes ---------- +#[derive(Debug, Serialize, Deserialize, Clone, FromRow)] +pub struct ChatMsg { + pub id: Uuid, + pub display_name: Option, + pub user_id: i64, + pub text: String, + pub timestamp: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +pub enum ChatEvent { + SendMessage(ChatMsg), + + /// for when a user explicitly edits a message + EditMessage { id: Uuid, msg: ChatMsg }, + + /// used for streaming content to a message + /// will not show up as edited + MessageAppendContent{ id: Uuid, content: String } +} \ No newline at end of file diff --git a/backend/src/model/mod.rs b/backend/src/model/mod.rs index 23732a7..e854308 100644 --- a/backend/src/model/mod.rs +++ b/backend/src/model/mod.rs @@ -1,3 +1,4 @@ pub mod auth; pub mod user; -pub mod space; \ No newline at end of file +pub mod space; +pub mod event; \ No newline at end of file diff --git a/backend/src/repo/message_repo.rs b/backend/src/repo/message_repo.rs index ea9b4a8..98c242e 100644 --- a/backend/src/repo/message_repo.rs +++ b/backend/src/repo/message_repo.rs @@ -1,68 +1,80 @@ -use crate::api::chat::ChatMsg; +use crate::model::event::ChatMsg; use crate::repo::Repo; use chrono::{DateTime, Utc}; use sqlx::PgPool; +use uuid::Uuid; #[derive(Clone)] pub struct MessageRepository { pool: PgPool } -impl Repo for MessageRepository { - type Target = ChatMsg; - - fn new(pool: PgPool) -> Self { +impl MessageRepository { + pub(crate) fn new(pool: PgPool) -> Self { Self { pool } } // TODO: caching with redis - async fn get_by_id(&self, id: i64) -> Option { + async fn get_by_id(&self, id: Uuid) -> Option { sqlx::query!( - "SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at + "SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at FROM messages m JOIN users u ON m.user_id = u.id WHERE m.id = $1", id ).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg { - display_name: Some(row.nickname.unwrap_or(row.username)), - user_id: row.user_id, - text: row.content, - timestamp: row.created_at, - }) + id: row.id, + display_name: Some(row.nickname.unwrap_or(row.username)), + user_id: row.user_id, + text: row.content, + timestamp: row.created_at, + }) } -} -impl MessageRepository { // TODO! caching with redis pub async fn create_new( - &self, uid: i64, channel_id: i64, - text: &str, created_at: DateTime - ) -> Result { + &self, msg: ChatMsg, channel_id: i64 + ) -> Result<(), sqlx::Error> { sqlx::query!( - "INSERT INTO messages (channel_id, user_id, content, created_at) - VALUES ($1, $2, $3, $4) RETURNING id", + "INSERT INTO messages (id, channel_id, user_id, content, created_at) + VALUES ($1, $2, $3, $4, $5)", + msg.id, channel_id, - uid, + msg.user_id, + msg.text, + msg.timestamp + ).execute(&self.pool).await.map_err(|_| sqlx::Error::RowNotFound)?; + Ok(()) + } + + pub async fn update_text(&self, id: Uuid, text: &str) -> Result<(), sqlx::Error> { + sqlx::query!( + "UPDATE messages SET content = $1 WHERE id = $2", text, - created_at - ).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound)) + id + ).execute(&self.pool).await?; + Ok(()) } /// TODO: caching with redis - pub async fn get_by_channel(&self, channel_id: i64, limit: usize) - -> Result, sqlx::Error> { + pub async fn get_latest_by_channel_desc( + &self, channel_id: i64, limit: usize, page: usize + ) -> Result, sqlx::Error> { sqlx::query!( - "SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at + "SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at FROM messages m JOIN users u ON m.user_id = u.id WHERE m.channel_id = $1 - ORDER BY m.created_at DESC LIMIT $2", + ORDER BY m.created_at DESC + LIMIT $2 OFFSET $3", channel_id, - limit as i64 + limit as i64, + page as i64 ).fetch_all(&self.pool).await.map(|messages| { - messages.into_iter().rev().map(|msg| { + messages.into_iter().map(|msg| { ChatMsg { + id: msg.id, display_name: Some(msg.nickname.unwrap_or(msg.username)), user_id: msg.user_id, text: msg.content, diff --git a/backend/src/svc/chat_svc.rs b/backend/src/svc/chat_svc.rs index 0572f8a..9d75b39 100644 --- a/backend/src/svc/chat_svc.rs +++ b/backend/src/svc/chat_svc.rs @@ -1,12 +1,15 @@ -use crate::api::chat::ChatMsg; +use crate::model::event::{ChatEvent, ChatMsg}; use crate::error::{ApiResult, AppError}; use crate::repo::message_repo::MessageRepository; use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo}; use chrono::{DateTime, Utc}; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; +use std::sync::mpsc::channel; +use std::time::Instant; use tokio::sync::broadcast::Sender; use tokio::sync::{broadcast, Mutex}; +use uuid::Uuid; use crate::model::space::SpaceDto; use crate::svc::llm_service::LlmService; @@ -15,15 +18,13 @@ use crate::svc::llm_service::LlmService; #[derive(Clone)] pub struct ChatService { users: Arc, - channels: Arc, + channel_repo: Arc, spaces: Arc, messages: MessageRepository, llm: LlmService, buffer_size: usize, - senders: Arc>>>, - - + channels: Arc>>, } impl ChatService { @@ -33,13 +34,13 @@ impl ChatService { channels: Arc, spaces: Arc, ) -> Self { Self { - channels, + channel_repo: channels, spaces, llm, users, messages, buffer_size, - senders: Arc::new(Mutex::new(std::collections::HashMap::new())), + channels: Arc::new(Mutex::new(std::collections::HashMap::new())), } } @@ -50,7 +51,7 @@ impl ChatService { let mut result = Vec::new(); for space in spaces { - let channels = self.channels.get_by_space_id(space.id).await?; + let channels = self.channel_repo.get_by_space_id(space.id).await?; result.push(SpaceDto { channels, id: space.id, @@ -65,8 +66,10 @@ impl ChatService { Ok(result) } - pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult> { - let messages = self.messages.get_by_channel(channel_id, limit).await?; + pub async fn fetch_latest_messages_desc(&self, channel_id: i64, limit: usize) -> ApiResult> { + const PAGE: usize = 0; + + let messages = self.messages.get_latest_by_channel_desc(channel_id, limit, PAGE).await?; Ok(messages) } @@ -113,6 +116,7 @@ impl ChatService { .ok_or(AppError::NotFound)?; let message = ChatMsg { + id: Uuid::new_v4(), display_name: Some(user .nickname.clone() .unwrap_or_else(|| user.username.clone())), @@ -122,66 +126,124 @@ impl ChatService { }; self.publish(channel_id, message.clone()).await; - - let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?; + self.messages.create_new(message.clone(), channel_id).await?; // TODO: caching w redis at repository layer - let svc_instance = self.clone(); - - let Some(text) = text.strip_prefix("/ask ") else { - return Ok(()) - }; - - if !svc_instance.llm.enabled() { - return Ok(()) + if !message.text.starts_with("/ask ") { + return Ok(()); } + + let svc_instance = self.clone(); + if !svc_instance.llm.enabled() { + return Ok(()); + } + + let context = self.get_history(channel_id, 25).await?; tokio::spawn(async move { + let sender = match svc_instance.channels.lock().await.get(&channel_id) { + Some(s) => s.get_sender(), + None => return, + }; + let response = svc_instance.llm - .query(&message) + .query(&message, &context, sender) .await; - if let Ok(reply) = response { - - tracing::info!("LLM reply: {}", reply.text); - - svc_instance.publish(channel_id, reply.clone()).await; - // TODO: cache response (or do with redis!) - if let Err(e) = svc_instance.messages - .create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await { - tracing::error!("Failed to persist LLM reply: {}", e); - } - - tracing::info!("LLM reply persisted"); - - } else { + let Ok(reply) = response else { tracing::warn!("Error contacting LLM: {:?}", response); + return; + }; + + // TODO: cache response (or do with redis!) + if let Err(e) = svc_instance.messages.create_new(reply, channel_id).await { + tracing::error!("Failed to persist LLM reply: {}", e); } + + tracing::debug!("Full LLM reply persisted"); }); Ok(()) } + async fn get_history(&self, channel_id: i64, limit: usize) -> ApiResult> { + let mut map = self.channels.lock().await; + if let Some(channel) = map.get(&channel_id) && !channel.cache.is_empty() { + Ok(channel.history().clone().into_iter().take(limit).collect()) + } else { + let messages: Vec<_> = self.messages.get_latest_by_channel_desc(channel_id, limit, 0).await?.into_iter().rev().collect(); + map.insert(channel_id, + ChannelState::new(self.buffer_size, Some(messages.clone().into())) + ); + Ok(messages) + } + } + /// Subscribe to the specified channel. - pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver { - let mut map = self.senders.lock().await; - let sender = map + pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver { + let mut map = self.channels.lock().await; + let channel = map .entry(channel_id) - .or_insert_with(|| broadcast::channel::(self.buffer_size).0); - sender.subscribe() + .or_insert_with(|| ChannelState::new(self.buffer_size, None)); + channel.subscribe() } // Private helper methods /// Publish a message to the specified channel. async fn publish(&self, channel_id: i64, msg: ChatMsg) { - let mut map = self.senders.lock().await; - let sender = map + let mut map = self.channels.lock().await; + let channel = map .entry(channel_id) - .or_insert_with(|| broadcast::channel::(self.buffer_size).0); - let _ = sender.send(msg); + .or_insert_with(|| ChannelState::new(self.buffer_size, None)); + channel.send(msg); + } +} + +#[derive(Clone)] +pub struct ChannelState { + sender: Sender, + cache: VecDeque, + last_updated: Instant, +} + +impl ChannelState { + + const MAX_HISTORY_SIZE: usize = 100; + + #[must_use] + pub fn new(buffer_size: usize, history: Option>) -> Self { + Self { + sender: broadcast::channel(buffer_size).0, + cache: history.unwrap_or_default(), + last_updated: Instant::now(), + } } + pub fn history(&self) -> &VecDeque { + &self.cache + } + pub fn get_sender(&self) -> Sender { + self.sender.clone() + } + + #[must_use] + pub fn send(&mut self, msg: ChatMsg) { + while self.cache.len() >= Self::MAX_HISTORY_SIZE { + self.cache.pop_front(); + } + + self.cache.push_back(msg.clone()); + + if self.sender.send(ChatEvent::SendMessage(msg)).is_err() { + tracing::warn!("Sent message to channel with no subscribers"); + } + } + + #[must_use] + pub fn subscribe(&self) -> broadcast::Receiver { + self.sender.subscribe() + } +} -} \ No newline at end of file diff --git a/backend/src/svc/llm_service.rs b/backend/src/svc/llm_service.rs index 9b0cace..2d8ddfc 100644 --- a/backend/src/svc/llm_service.rs +++ b/backend/src/svc/llm_service.rs @@ -1,3 +1,13 @@ +use std::env; +use std::sync::LazyLock; +use futures_util::StreamExt; +use serde::{Deserialize, Serialize}; +use tokio::sync::broadcast::Sender; +use uuid::Uuid; +use crate::model::event::{ChatEvent, ChatMsg}; +use crate::error::{ApiResult, AppError}; + + #[derive(Clone)] pub struct LlmService; @@ -15,71 +25,139 @@ impl LlmService { LMSTUDIO_URL.is_some() } - pub async fn query(&self, message: &ChatMsg) -> ApiResult { + pub async fn query(&self, message: &ChatMsg, context: &[ChatMsg], sender: Sender) -> ApiResult { let Some(url) = LMSTUDIO_URL.clone() else { return Err(AppError::internal("AI not enabled!")) }; - let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into()); - - let client = reqwest::Client::new(); - - // Build the request body - let payload = LlmRequest { - model, // whatever model you run locally - messages: vec![Message { - role: "user".into(), - content: message.text.clone(), - }], + let reply_id = Uuid::new_v4(); + let timestamp = chrono::Utc::now(); + let mut reply = ChatMsg { + id: reply_id, + display_name: Some("llm".into()), + user_id: 0, + text: String::new(), + timestamp, }; + let _ = sender.send(ChatEvent::SendMessage(reply.clone())); - // POST to lm‑studio (default 127.0.0.1:1234) - let resp = client + let mut messages: Vec = Vec::new(); + let system_prompt = format!( + "You are a helpful assistant in a group chat. \ + You are talking to '{}'. \ + Keep responses concise and conversational.", + message.display_name.as_deref().unwrap_or("unknown"), + ); + messages.push(Message { role: "system".into(), content: system_prompt }); + + for msg in context { + let role = if msg.user_id == 0 { + "assistant" // your LLM user_id convention + } else { + "user" + }; + messages.push(Message { + role: role.into(), + content: format!( + "{}: {}", + msg.display_name.as_deref().unwrap_or("unknown"), + msg.text + ), + }); + } + + messages.push(Message { + role: "user".into(), + content: format!( + "{}: {}", + message.display_name.as_deref().unwrap_or("unknown"), + message.text.trim_start_matches("/ask ") // strip the command prefix + ), + }); + + let Ok(resp) = reqwest::Client::new() .post(url) - .json(&payload) + .json(&LlmRequest { + think: false, + model: LMSTUDIO_MODEL + .clone() + .unwrap_or_else(|| "gpt-oss-20b".into()), + messages, + stream: true, + }) .send() .await - .map_err(|_| AppError::internal("Failed to make request to LLM server"))?; + else { + tracing::warn!("Failed to reach LLM"); + let _ = sender.send(ChatEvent::MessageAppendContent { + id: reply_id, + content: String::from("I'm not available right now. Please try again later.") + }); + return Err(AppError::internal("Failed to reach LLM")); + }; - // The API returns a JSON with `choices[].message.content` - #[derive(Deserialize)] - struct LlmResponse { - choices: Vec, - } - #[derive(Deserialize)] - struct Choice { - message: Message, + + let mut full_text = String::new(); + let mut buffer = String::new(); + let mut stream = resp.bytes_stream(); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|_| AppError::internal("Stream error"))?; + buffer.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(pos) = buffer.find('\n') { + let line = buffer[..pos].trim().to_string(); + buffer = buffer[pos + 1..].to_string(); + + if line == "data: [DONE]" { + break; + } + + if let Some(json) = line.strip_prefix("data: ") + && let Ok(parsed) = serde_json::from_str::(json) + && let Some(content) = parsed.choices[0].delta.content.as_ref() { + full_text.push_str(content); + let _ = sender.send(ChatEvent::MessageAppendContent { + id: reply_id, + content: content.clone(), + }); + } + } } - let llm_resp: LlmResponse = resp - .json() - .await - .map_err(|_| AppError::internal("Failed to parse LLM response"))?; + reply.text = full_text; - Ok(ChatMsg { - display_name: Some(String::from("llm")), - user_id: 0, - text: llm_resp.choices[0].message.content.clone(), - timestamp: chrono::Utc::now(), - }) + Ok(reply) } } -use std::env; -use std::sync::LazyLock; -// src/llm.rs -use serde::{Deserialize, Serialize}; -use crate::api::chat::ChatMsg; -use crate::error::{ApiResult, AppError}; -use crate::svc::chat_svc::ChatService; #[derive(Serialize)] struct LlmRequest { model: String, messages: Vec, + stream: bool, + think: bool, } +#[derive(Deserialize)] +struct StreamingResponse { + choices: Vec, +} + +#[derive(Deserialize)] +struct StreamingChoice { + delta: Delta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Deserialize)] +struct Delta { + #[serde(default)] + content: Option, +} #[derive(Serialize, Deserialize)] struct Message { role: String, // "user" or "assistant" diff --git a/backend/src/svc/mod.rs b/backend/src/svc/mod.rs index 7c366a3..ba177dd 100644 --- a/backend/src/svc/mod.rs +++ b/backend/src/svc/mod.rs @@ -3,4 +3,5 @@ pub mod chat_svc; pub mod settings_svc; pub mod user_svc; pub mod access_token_svc; -pub mod llm_service; \ No newline at end of file +pub mod llm_service; +pub mod relationship_svc; \ No newline at end of file diff --git a/backend/src/svc/relationship_svc.rs b/backend/src/svc/relationship_svc.rs new file mode 100644 index 0000000..1f34fb0 --- /dev/null +++ b/backend/src/svc/relationship_svc.rs @@ -0,0 +1,6 @@ +pub struct RelationshipService {} + +impl RelationshipService { + pub fn new() -> Self { Self {} } +} + diff --git a/backend/tests/chat_test.py b/backend/tests/chat_test.py index e69de29..8cebea8 100644 --- a/backend/tests/chat_test.py +++ b/backend/tests/chat_test.py @@ -0,0 +1,228 @@ +import argparse +import json +import threading +import time +from dataclasses import dataclass +from getpass import getpass +from typing import List + +import requests + + +BASE_URL = "http://localhost:8000" + + +@dataclass +class AuthResult: + token: str + + +def signup(session: requests.Session, email: str, username: str, password: str, access_token: str) -> None: + url = f"{BASE_URL}/api/signup" + payload = { + "email": email, + "username": username, + "password": password, + "access_token": access_token, + } + resp = session.post(url, json=payload, timeout=10) + resp.raise_for_status() + + +def login(session: requests.Session, username: str, password: str) -> AuthResult: + url = f"{BASE_URL}/api/login" + payload = { + "username": username, + "password": password, + } + resp = session.post(url, json=payload, timeout=10) + resp.raise_for_status() + + data = resp.json() + token = data.get("token") + if not token: + raise RuntimeError(f"Login response did not contain token: {data}") + return AuthResult(token=token) + + +def post_message(session: requests.Session, channel_id: int, token: str, text: str, display_name: str, user_id: int) -> None: + url = f"{BASE_URL}/api/chat/{channel_id}" + payload = { + "display_name": display_name, + "user_id": user_id, + "text": text, + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + } + headers = {"Authorization": f"Bearer {token}"} + resp = session.post(url, json=payload, headers=headers, timeout=10) + resp.raise_for_status() + + +def read_sse_messages( + session: requests.Session, + channel_id: int, + token: str, + expected_count: int, + timeout_s: int, + capture_live_messages: threading.Event, +) -> List[dict]: + url = f"{BASE_URL}/api/events/{channel_id}" + headers = { + "Authorization": f"Bearer {token}", + "Accept": "text/event-stream", + } + + received: List[dict] = [] + deadline = time.monotonic() + timeout_s + + try: + with session.get(url, headers=headers, stream=True, timeout=(5, timeout_s)) as resp: + resp.raise_for_status() + + event_data_lines: List[str] = [] + + for raw_line in resp.iter_lines(decode_unicode=True): + if time.monotonic() > deadline: + break + + if raw_line is None: + continue + + line = raw_line.strip() + + if not line: + if event_data_lines: + joined = "\n".join(event_data_lines) + event_data_lines.clear() + try: + obj = json.loads(joined) + except json.JSONDecodeError: + continue + + if capture_live_messages.is_set(): + received.append(obj) + if expected_count > 0 and len(received) >= expected_count: + break + else: + print(f"Discarding message: {obj}") + continue + + if line.startswith("data:"): + event_data_lines.append(line[len("data:"):].strip()) + + except requests.exceptions.Timeout: + print("Timeout while reading SSE.") + except requests.exceptions.RequestException as exc: + print(f"Error reading SSE: {exc}") + + return received + + +def prompt_nonempty(label: str, secret: bool = False) -> str: + while True: + value = getpass(label) if secret else input(label) + value = value.strip() + if value: + return value + print("Please enter a value.") + + +def main() -> int: + parser = argparse.ArgumentParser(description="Chat integration test against localhost:8000") + parser.add_argument("--existing-account", action="store_true", + help="Skip signup and only log in with an existing account") + parser.add_argument("--email", default=None) + parser.add_argument("--username", default=None) + parser.add_argument("--password", default=None) + parser.add_argument("--access-token", default=None, + help="Required only for signup mode") + parser.add_argument("--channel-id", type=int, default=1) + parser.add_argument("--message-count", type=int, default=5) + parser.add_argument("--timeout", type=int, default=15) + args = parser.parse_args() + + session = requests.Session() + + if args.existing_account: + username = args.username or prompt_nonempty("Username: ") + password = args.password or prompt_nonempty("Password: ", secret=True) + else: + email = args.email or prompt_nonempty("Email: ") + username = args.username or prompt_nonempty("Username: ") + password = args.password or prompt_nonempty("Password: ", secret=True) + access_token = args.access_token or prompt_nonempty("Access token: ") + + print("[1/5] Signing up...") + try: + signup(session, email, username, password, access_token) + print(" signup ok") + except requests.HTTPError as e: + print(f" signup returned HTTP error: {e}") + print(" continuing to login...") + + print("[2/5] Logging in...") + auth = login(session, username, password) + print(" login ok") + print(f" token: {auth.token[:12]}...") + + print("[3/5] Opening event stream...") + received_messages: List[dict] = [] + capture_live_messages = threading.Event() + stream_done = threading.Event() + + def stream_reader() -> None: + nonlocal received_messages + try: + received_messages = read_sse_messages( + session=session, + channel_id=args.channel_id, + token=auth.token, + expected_count=args.message_count, + timeout_s=args.timeout, + capture_live_messages=capture_live_messages, + ) + finally: + stream_done.set() + + t = threading.Thread(target=stream_reader, daemon=True) + t.start() + + # Give the server time to flush backlog on this same stream connection. + time.sleep(1.0) + + print("[4/5] Starting to capture live messages and sending messages...") + capture_live_messages.set() + + sent_texts = [f"Message {i}" for i in range(args.message_count)] + for i, text in enumerate(sent_texts): + post_message( + session=session, + channel_id=args.channel_id, + token=auth.token, + text=text, + display_name=username, + user_id=1, + ) + print(f" sent {i + 1}/{args.message_count}: {text}") + time.sleep(0.1) + + stream_done.wait(timeout=args.timeout) + t.join(timeout=1) + + print("\nReceived messages:") + for i, msg in enumerate(received_messages, start=1): + print(f" {i}. {msg}") + + received_texts = [m.get("text") for m in received_messages if isinstance(m, dict)] + + for text in sent_texts: + if text not in received_texts: + print(f"\nFAIL: missing message: {text}") + return 1 + + print("\nPASS: login and message delivery test succeeded.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file