This commit is contained in:
2026-06-03 19:12:23 +01:00
parent 2f34976f3e
commit 7e001d8769
27 changed files with 727 additions and 188 deletions
+1 -1
View File
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android" <manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"> xmlns:tools="http://tools.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" /> <uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS"/> <uses-permission android:name="android.permission.POST_NOTIFICATIONS"/>
@@ -1,9 +1,13 @@
package dev.zxq5.chatapp.android package dev.zxq5.chatapp.android
import android.Manifest
import android.os.Build
import android.os.Bundle import android.os.Bundle
import androidx.activity.ComponentActivity import androidx.activity.ComponentActivity
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.compose.setContent import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
@@ -62,9 +66,22 @@ class MainActivity : ComponentActivity() {
val currentScreen by chatViewModel.currentScreen.collectAsState() val currentScreen by chatViewModel.currentScreen.collectAsState()
val selectedChannelId by chatViewModel.channelId.collectAsState() val selectedChannelId by chatViewModel.channelId.collectAsState()
// Permission request launcher
val launcher = rememberLauncherForActivityResult(
contract = ActivityResultContracts.RequestPermission(),
onResult = { isGranted ->
if (isGranted && authState == AuthState.Authenticated) {
MessageStreamService.start(this@MainActivity)
}
}
)
LaunchedEffect(authState) { LaunchedEffect(authState) {
when (authState) { when (authState) {
AuthState.Authenticated -> { AuthState.Authenticated -> {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
launcher.launch(Manifest.permission.POST_NOTIFICATIONS)
}
MessageStreamService.start(this@MainActivity) MessageStreamService.start(this@MainActivity)
chatViewModel.loadAccessibleChannels() chatViewModel.loadAccessibleChannels()
} }
@@ -1,7 +1,9 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.api package dev.zxq5.chatapp.android.api
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.api.model.SendMessage import dev.zxq5.chatapp.android.api.model.SendMessage
import dev.zxq5.chatapp.android.api.model.SpaceDto import dev.zxq5.chatapp.android.api.model.SpaceDto
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
@@ -25,6 +27,8 @@ import kotlinx.coroutines.flow.flow
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import kotlin.time.Clock import kotlin.time.Clock
import kotlin.time.ExperimentalTime import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
class ChatClient(private val token: String) { class ChatClient(private val token: String) {
@@ -45,18 +49,18 @@ class ChatClient(private val token: String) {
suspend fun sendMessage(channelId: Long, userId: Int, text: String) { suspend fun sendMessage(channelId: Long, userId: Int, text: String) {
http.post("${BASE_URL}/api/chat/$channelId") { http.post("${BASE_URL}/api/chat/$channelId") {
contentType(ContentType.Application.Json) contentType(ContentType.Application.Json)
setBody(SendMessage(user_id = userId, text = text, timestamp = Clock.System.now())) setBody(SendMessage(id = Uuid.random(), user_id = userId, text = text, timestamp = Clock.System.now()))
} }
} }
fun messageStream(channelId: Long): Flow<Message> = flow { fun eventStream(channelId: Long): Flow<ChatEvent> = flow {
http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response -> http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response ->
val channel = response.bodyAsChannel() val channel = response.bodyAsChannel()
while (!channel.isClosedForRead) { while (!channel.isClosedForRead) {
val line = channel.readLine() ?: break val line = channel.readLine() ?: break
if (line.startsWith("data:")) { if (line.startsWith("data:")) {
val json = line.removePrefix("data:").trim() val json = line.removePrefix("data:").trim()
runCatching { Json.decodeFromString<Message>(json) } runCatching { Json.decodeFromString<ChatEvent>(json) }
.onSuccess { emit(it) } .onSuccess { emit(it) }
} }
} }
@@ -0,0 +1,60 @@
@file:OptIn(ExperimentalUuidApi::class, ExperimentalTime::class)
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialName
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
import kotlinx.serialization.Serializable
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
@OptIn(ExperimentalSerializationApi::class)
@Serializable
@JsonClassDiscriminator("type")
sealed class ChatEvent {
@Serializable
@SerialName("SendMessage")
data class SendMessage(
val data: Message
) : ChatEvent()
@Serializable
@SerialName("EditMessage")
data class EditMessage(
val data: EditMessageContent
) : ChatEvent()
@Serializable
@SerialName("MessageAppendContent")
data class MessageAppendContent(
val data: AppendContent
) : ChatEvent()
}
// tuple variants like (i64, ChatMsg) and (i64, String)
// need wrapper classes since kotlinx can't deserialise
// bare JSON arrays into data classes directly
@Serializable
data class EditMessageContent(
val id: Uuid,
val message: Message
)
@Serializable
data class AppendContent (
val id: Uuid,
val content: String
)
@Serializable
data class Message (
val id: Uuid,
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Instant
)
@@ -1,13 +0,0 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Message @OptIn(ExperimentalTime::class) constructor(
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Instant
)
@@ -3,9 +3,12 @@ package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime import kotlin.time.ExperimentalTime
import kotlin.time.Instant import kotlin.time.Instant
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
@Serializable @Serializable
data class SendMessage @OptIn(ExperimentalTime::class) constructor( data class SendMessage @OptIn(ExperimentalTime::class, ExperimentalUuidApi::class) constructor(
val id: Uuid,
val user_id: Int, val user_id: Int,
val text: String, val text: String,
val timestamp: Instant val timestamp: Instant
@@ -6,6 +6,7 @@ import android.content.Intent
import android.os.IBinder import android.os.IBinder
import android.util.Log import android.util.Log
import dev.zxq5.chatapp.android.ChatApplication import dev.zxq5.chatapp.android.ChatApplication
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.data.repository.ChatRepository
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@@ -37,8 +38,6 @@ class MessageStreamService : Service() {
fun start(context: Context) { fun start(context: Context) {
val intent = Intent(context, MessageStreamService::class.java) val intent = Intent(context, MessageStreamService::class.java)
// Use startService to avoid the requirement for a persistent notification.
// This also prevents ForegroundServiceDidNotStartInTimeException.
context.startService(intent) context.startService(intent)
} }
@@ -64,17 +63,20 @@ class MessageStreamService : Service() {
if (channelId == null) return if (channelId == null) return
currentStreamJob = serviceScope.launch { currentStreamJob = serviceScope.launch {
chatRepository.messageStream(channelId) chatRepository.eventStream(channelId)
.catch { e -> Log.e("Service", "Stream error", e) } .catch { e -> Log.e("Service", "Stream error", e) }
.collect { message -> .collect { event ->
// Only show notification when an event (new message) is received // Only show notification when an event (new message) is received
// and the app is not in the foreground on this channel. // and the app is not in the foreground on this channel.
if (!ChatApplication.AppState.isInForeground || activeChannelId != channelId) { if (!ChatApplication.AppState.isInForeground || activeChannelId != channelId) {
notificationService.showMessageNotification( when (event) {
conversationId = channelId.toString(), is ChatEvent.SendMessage -> notificationService.showMessageNotification(
senderName = message.display_name, conversationId = channelId.toString(),
messagePreview = message.text senderName = event.data.display_name,
) messagePreview = event.data.text
)
else -> {}
}
} }
} }
} }
@@ -1,6 +1,7 @@
package dev.zxq5.chatapp.android.data.repository package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.ChatClient import dev.zxq5.chatapp.android.api.ChatClient
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SpaceDto import dev.zxq5.chatapp.android.api.model.SpaceDto
@@ -43,8 +44,8 @@ class ChatRepository(private val tokenStore: TokenStore) {
getChatClient()?.sendMessage(channelId, userId, text) getChatClient()?.sendMessage(channelId, userId, text)
} }
fun messageStream(channelId: Long): Flow<Message> { fun eventStream(channelId: Long): Flow<ChatEvent> {
_lastActiveChannel = channelId _lastActiveChannel = channelId
return getChatClient()?.messageStream(channelId) ?: emptyFlow() return getChatClient()?.eventStream(channelId) ?: emptyFlow()
} }
} }
@@ -1,12 +1,13 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.feature.chat package dev.zxq5.chatapp.android.feature.chat
import android.util.Log import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.Channel import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.data.repository.ChatRepository
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.Space
import dev.zxq5.chatapp.android.api.model.SpaceDto import dev.zxq5.chatapp.android.api.model.SpaceDto
import dev.zxq5.chatapp.android.core.service.MessageStreamService import dev.zxq5.chatapp.android.core.service.MessageStreamService
import io.ktor.client.plugins.ResponseException import io.ktor.client.plugins.ResponseException
@@ -17,6 +18,8 @@ import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() { class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
@@ -70,6 +73,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
} }
} }
@OptIn(ExperimentalTime::class)
private fun observeChannel() { private fun observeChannel() {
viewModelScope.launch { viewModelScope.launch {
_channelId.collect { id -> _channelId.collect { id ->
@@ -78,7 +82,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_channelError.value = null _channelError.value = null
if (id != null) { if (id != null) {
streamJob = launch { streamJob = launch {
chatRepository.messageStream(id) chatRepository.eventStream(id)
.catch { e -> .catch { e ->
Log.e("Chat", "Stream error", e) Log.e("Chat", "Stream error", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) { if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
@@ -87,8 +91,31 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_channelError.value = "Connection lost: ${e.message}" _channelError.value = "Connection lost: ${e.message}"
} }
} }
.collect { message -> .collect { event ->
_messages.update { it + message } when (event) {
is ChatEvent.SendMessage -> {
_messages.update { it + event.data }
}
is ChatEvent.EditMessage -> {
_messages.update { messages ->
messages.map {
if (it.id == event.data.id) event.data.message
else it
}
}
}
is ChatEvent.MessageAppendContent -> {
_messages.update { messages ->
messages.map { msg ->
if (msg.id == event.data.id) {
msg.copy(text = msg.text + event.data.content)
} else {
msg
}
}
}
}
}
} }
} }
} }
@@ -1,3 +1,5 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.feature.chat package dev.zxq5.chatapp.android.feature.chat
import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.BorderStroke
@@ -65,6 +67,7 @@ import dev.zxq5.chatapp.android.api.model.Message
import java.text.DateFormat import java.text.DateFormat
import java.util.Date import java.util.Date
import kotlin.time.ExperimentalTime import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
@Composable @Composable
fun ChatScreen( fun ChatScreen(
@@ -277,7 +280,7 @@ fun MessageScreen(channelId: Long, viewModel: ChatViewModel, onBack: () -> Unit)
modifier = Modifier.weight(1f).padding(horizontal = 16.dp), modifier = Modifier.weight(1f).padding(horizontal = 16.dp),
verticalArrangement = Arrangement.spacedBy(10.dp) verticalArrangement = Arrangement.spacedBy(10.dp)
) { ) {
items(messages) { message -> items(messages, key = { it.id }) { message ->
MessageBubble(message, currentUserId) MessageBubble(message, currentUserId)
} }
item { Spacer(Modifier.height(10.dp)) } item { Spacer(Modifier.height(10.dp)) }
@@ -378,7 +381,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
horizontalAlignment = if (isMe) Alignment.End else Alignment.Start horizontalAlignment = if (isMe) Alignment.End else Alignment.Start
) { ) {
Surface( Surface(
color = if (isMe) MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f), color = if (isMe) MaterialTheme.colorScheme.surfaceVariant else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f),
shape = RoundedCornerShape( shape = RoundedCornerShape(
topStart = 14.dp, topStart = 14.dp,
topEnd = 14.dp, topEnd = 14.dp,
@@ -388,14 +391,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f)) border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f))
) { ) {
Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) { Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) {
if (!isMe) {
Text(
message.display_name?.lowercase() ?: "unknown",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.primary.copy(alpha = 0.7f),
modifier = Modifier.padding(bottom = 2.dp)
)
}
Text( Text(
text = message.text, text = message.text,
style = MaterialTheme.typography.bodyLarge, style = MaterialTheme.typography.bodyLarge,
@@ -403,10 +399,11 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
) )
} }
} }
Text( Text(
text = time, text = if (!isMe) message.display_name.lowercase() + " . " + time else time,
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f), color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp) modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp)
) )
} }
@@ -2,7 +2,7 @@ package dev.zxq5.chatapp.android.ui.theme
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
val Black = Color(0xFF0A0A0A) val Black = Color(0xFF000000)
val DarkGrey = Color(0xFF0D0D0D) val DarkGrey = Color(0xFF0D0D0D)
val Grey = Color(0xFF141414) val Grey = Color(0xFF141414)
val LightGrey = Color(0xFF1E1E1E) val LightGrey = Color(0xFF1E1E1E)
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="uk.co.ben_gibson.git.link.SettingsState">
<option name="host" value="e0f86390-1091-4871-8aeb-f534fbc99cf0" />
</component>
</project>
+1 -1
View File
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="DataSourcePerFileMappings"> <component name="DataSourcePerFileMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" /> <file url="file://$PROJECT_DIR$/sql/test.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
<file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" /> <file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
</component> </component>
</project> </project>
+2 -1
View File
@@ -1,7 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="SqlDialectMappings"> <component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" dialect="PostgreSQL" /> <file url="file://$PROJECT_DIR$/migrations/20260412200102_message_id_to_uuid.sql" dialect="PostgreSQL" />
<file url="file://$PROJECT_DIR$/sql/test.sql" dialect="PostgreSQL" />
<file url="PROJECT" dialect="PostgreSQL" /> <file url="PROJECT" dialect="PostgreSQL" />
</component> </component>
</project> </project>
+3 -3
View File
@@ -12,7 +12,7 @@ image = "0.25.8"
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] } jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
rand = "0.8" rand = "0.8"
redis = { version = "0.25.4", features = ["tokio-comp"] } redis = { version = "0.25.4", features = ["tokio-comp"] }
reqwest = { version = "0.12.23", features = ["json"] } reqwest = { version = "0.12.23", features = ["json", "stream"] }
rocket = { version = "0.5.1", features = ["json", "secrets"] } rocket = { version = "0.5.1", features = ["json", "secrets"] }
rocket_cors = "0.6.0" rocket_cors = "0.6.0"
rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] } rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] }
@@ -20,11 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] }
serde = { version = "1.0.228", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145" serde_json = "1.0.145"
sha2 = "0.10.9" sha2 = "0.10.9"
sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] } sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time", "uuid"] }
tokio = { version = "1.47.1", features = ["full"] } tokio = { version = "1.47.1", features = ["full"] }
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] } totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
tracing = "0.1.44" tracing = "0.1.44"
uuid = { version = "1.18.1", features = ["v4"] } uuid = { version = "1.18.1", features = ["serde", "v4"] }
thiserror = "1.0.69" thiserror = "1.0.69"
utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] } utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] }
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
+1 -1
View File
@@ -1,7 +1,7 @@
[debug] [debug]
secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU=" secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU="
address = "0.0.0.0" address = "0.0.0.0"
port = 8000 port = 8080
[debug.databases.postgres_db] [debug.databases.postgres_db]
url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev" url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev"
@@ -0,0 +1,9 @@
ALTER TABLE attachments DROP CONSTRAINT attachments_message_id_fkey;
ALTER TABLE messages ALTER COLUMN id DROP DEFAULT;
ALTER TABLE messages ALTER COLUMN id TYPE uuid USING gen_random_uuid();
ALTER TABLE messages ALTER COLUMN id SET DEFAULT gen_random_uuid();
ALTER TABLE attachments ALTER COLUMN message_id TYPE uuid USING gen_random_uuid();
ALTER TABLE attachments ADD CONSTRAINT attachments_message_id_fkey
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE;
+17
View File
@@ -0,0 +1,17 @@
WITH space1 AS (
INSERT INTO spaces (name, description, owner_id)
VALUES ('general', 'Boring chat idk', 1)
RETURNING id
),
space2 AS (
INSERT INTO spaces (name, description, owner_id)
VALUES ('Gaming', 'we lose games', 1)
RETURNING id
)
INSERT INTO channels (name, description, space_id)
SELECT 'General', 'General chat', id FROM space1 UNION ALL
SELECT 'Coding', 'Coding stuff', id FROM space1 UNION ALL
SELECT 'AI', '"/ask" here pls :)', id FROM space1 UNION ALL
SELECT 'The Game', '(You lost)', id FROM space2 UNION ALL
SELECT 'Backrooms', 'Beware of Smilers', id FROM space2 UNION ALL
SELECT 'SE', 'Space/Software engineering.', id FROM space2;
+9 -16
View File
@@ -9,15 +9,7 @@ use rocket::{Shutdown, State, ___internal_EventStream as EventStream};
use sqlx::FromRow; use sqlx::FromRow;
use tokio::select; use tokio::select;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use crate::model::event::{ChatEvent, ChatMsg};
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
#[post("/chat/<channel_id>", format = "json", data = "<msg>")] #[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message( pub async fn post_message(
@@ -36,24 +28,25 @@ pub async fn event_stream(
mut shutdown: Shutdown, mut shutdown: Shutdown,
channel_id: i64, channel_id: i64,
) -> ApiResult<EventStream![]> { ) -> ApiResult<EventStream![]> {
let messages = chat.get_messages(channel_id, 100) let messages = chat.fetch_latest_messages_desc(channel_id, 100)
.await?; // if get message returned err, inform user. .await?; // if get message returned err, inform user.
let mut rx = chat.subscribe(channel_id).await; let mut rx = chat.subscribe(channel_id).await;
let id = s.uid; let id = s.uid;
Ok(EventStream! { Ok(EventStream! {
for msg in messages { for msg in messages.into_iter().rev() {
yield Event::json(&msg); // tracing::info!("sending: {:?}", serde_json::to_string(&ChatEvent::SendMessage(msg.clone())).unwrap());
yield Event::json(&ChatEvent::SendMessage(msg));
} }
loop { loop {
select!{ select!{
_ = &mut shutdown => break, // exit early on shutdown _ = &mut shutdown => break, // exit early on shutdown
msg = rx.recv() => match msg { event = rx.recv() => match event {
Ok(msg) => { Ok(event) => {
tracing::info!("yielding message!"); // tracing::info!("yielding event: {event:?}");
yield Event::json(&msg) yield Event::json(&event)
}, },
Err(broadcast::error::RecvError::Lagged(n)) => { Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",); tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",);
+27
View File
@@ -0,0 +1,27 @@
use rocket::serde::{Deserialize, Serialize};
use sqlx::FromRow;
use chrono::{DateTime, Utc};
use uuid::Uuid;
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub id: Uuid,
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ChatEvent {
SendMessage(ChatMsg),
/// for when a user explicitly edits a message
EditMessage { id: Uuid, msg: ChatMsg },
/// used for streaming content to a message
/// will not show up as edited
MessageAppendContent{ id: Uuid, content: String }
}
+2 -1
View File
@@ -1,3 +1,4 @@
pub mod auth; pub mod auth;
pub mod user; pub mod user;
pub mod space; pub mod space;
pub mod event;
+40 -28
View File
@@ -1,68 +1,80 @@
use crate::api::chat::ChatMsg; use crate::model::event::ChatMsg;
use crate::repo::Repo; use crate::repo::Repo;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use sqlx::PgPool; use sqlx::PgPool;
use uuid::Uuid;
#[derive(Clone)] #[derive(Clone)]
pub struct MessageRepository { pub struct MessageRepository {
pool: PgPool pool: PgPool
} }
impl Repo for MessageRepository { impl MessageRepository {
type Target = ChatMsg; pub(crate) fn new(pool: PgPool) -> Self {
fn new(pool: PgPool) -> Self {
Self { pool } Self { pool }
} }
// TODO: caching with redis // TODO: caching with redis
async fn get_by_id(&self, id: i64) -> Option<Self::Target> { async fn get_by_id(&self, id: Uuid) -> Option<ChatMsg> {
sqlx::query!( sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at "SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m FROM messages m
JOIN users u ON m.user_id = u.id JOIN users u ON m.user_id = u.id
WHERE m.id = $1", WHERE m.id = $1",
id id
).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg { ).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg {
display_name: Some(row.nickname.unwrap_or(row.username)), id: row.id,
user_id: row.user_id, display_name: Some(row.nickname.unwrap_or(row.username)),
text: row.content, user_id: row.user_id,
timestamp: row.created_at, text: row.content,
}) timestamp: row.created_at,
})
} }
}
impl MessageRepository {
// TODO! caching with redis // TODO! caching with redis
pub async fn create_new( pub async fn create_new(
&self, uid: i64, channel_id: i64, &self, msg: ChatMsg, channel_id: i64
text: &str, created_at: DateTime<Utc> ) -> Result<(), sqlx::Error> {
) -> Result<i64, sqlx::Error> {
sqlx::query!( sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) "INSERT INTO messages (id, channel_id, user_id, content, created_at)
VALUES ($1, $2, $3, $4) RETURNING id", VALUES ($1, $2, $3, $4, $5)",
msg.id,
channel_id, channel_id,
uid, msg.user_id,
msg.text,
msg.timestamp
).execute(&self.pool).await.map_err(|_| sqlx::Error::RowNotFound)?;
Ok(())
}
pub async fn update_text(&self, id: Uuid, text: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE messages SET content = $1 WHERE id = $2",
text, text,
created_at id
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound)) ).execute(&self.pool).await?;
Ok(())
} }
/// TODO: caching with redis /// TODO: caching with redis
pub async fn get_by_channel(&self, channel_id: i64, limit: usize) pub async fn get_latest_by_channel_desc(
-> Result<Vec<ChatMsg>, sqlx::Error> { &self, channel_id: i64, limit: usize, page: usize
) -> Result<Vec<ChatMsg>, sqlx::Error> {
sqlx::query!( sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at "SELECT m.id, u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m FROM messages m
JOIN users u ON m.user_id = u.id JOIN users u ON m.user_id = u.id
WHERE m.channel_id = $1 WHERE m.channel_id = $1
ORDER BY m.created_at DESC LIMIT $2", ORDER BY m.created_at DESC
LIMIT $2 OFFSET $3",
channel_id, channel_id,
limit as i64 limit as i64,
page as i64
).fetch_all(&self.pool).await.map(|messages| { ).fetch_all(&self.pool).await.map(|messages| {
messages.into_iter().rev().map(|msg| { messages.into_iter().map(|msg| {
ChatMsg { ChatMsg {
id: msg.id,
display_name: Some(msg.nickname.unwrap_or(msg.username)), display_name: Some(msg.nickname.unwrap_or(msg.username)),
user_id: msg.user_id, user_id: msg.user_id,
text: msg.content, text: msg.content,
+108 -46
View File
@@ -1,12 +1,15 @@
use crate::api::chat::ChatMsg; use crate::model::event::{ChatEvent, ChatMsg};
use crate::error::{ApiResult, AppError}; use crate::error::{ApiResult, AppError};
use crate::repo::message_repo::MessageRepository; use crate::repo::message_repo::MessageRepository;
use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo}; use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use std::collections::HashMap; use std::collections::{HashMap, VecDeque};
use std::sync::Arc; use std::sync::Arc;
use std::sync::mpsc::channel;
use std::time::Instant;
use tokio::sync::broadcast::Sender; use tokio::sync::broadcast::Sender;
use tokio::sync::{broadcast, Mutex}; use tokio::sync::{broadcast, Mutex};
use uuid::Uuid;
use crate::model::space::SpaceDto; use crate::model::space::SpaceDto;
use crate::svc::llm_service::LlmService; use crate::svc::llm_service::LlmService;
@@ -15,15 +18,13 @@ use crate::svc::llm_service::LlmService;
#[derive(Clone)] #[derive(Clone)]
pub struct ChatService { pub struct ChatService {
users: Arc<dyn UserRepo>, users: Arc<dyn UserRepo>,
channels: Arc<dyn ChannelRepo>, channel_repo: Arc<dyn ChannelRepo>,
spaces: Arc<dyn SpaceRepo>, spaces: Arc<dyn SpaceRepo>,
messages: MessageRepository, messages: MessageRepository,
llm: LlmService, llm: LlmService,
buffer_size: usize, buffer_size: usize,
senders: Arc<Mutex<HashMap<i64, Sender<ChatMsg>>>>, channels: Arc<Mutex<HashMap<i64, ChannelState>>>,
} }
impl ChatService { impl ChatService {
@@ -33,13 +34,13 @@ impl ChatService {
channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>, channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>,
) -> Self { ) -> Self {
Self { Self {
channels, channel_repo: channels,
spaces, spaces,
llm, llm,
users, users,
messages, messages,
buffer_size, buffer_size,
senders: Arc::new(Mutex::new(std::collections::HashMap::new())), channels: Arc::new(Mutex::new(std::collections::HashMap::new())),
} }
} }
@@ -50,7 +51,7 @@ impl ChatService {
let mut result = Vec::new(); let mut result = Vec::new();
for space in spaces { for space in spaces {
let channels = self.channels.get_by_space_id(space.id).await?; let channels = self.channel_repo.get_by_space_id(space.id).await?;
result.push(SpaceDto { result.push(SpaceDto {
channels, channels,
id: space.id, id: space.id,
@@ -65,8 +66,10 @@ impl ChatService {
Ok(result) Ok(result)
} }
pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> { pub async fn fetch_latest_messages_desc(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
let messages = self.messages.get_by_channel(channel_id, limit).await?; const PAGE: usize = 0;
let messages = self.messages.get_latest_by_channel_desc(channel_id, limit, PAGE).await?;
Ok(messages) Ok(messages)
} }
@@ -113,6 +116,7 @@ impl ChatService {
.ok_or(AppError::NotFound)?; .ok_or(AppError::NotFound)?;
let message = ChatMsg { let message = ChatMsg {
id: Uuid::new_v4(),
display_name: Some(user display_name: Some(user
.nickname.clone() .nickname.clone()
.unwrap_or_else(|| user.username.clone())), .unwrap_or_else(|| user.username.clone())),
@@ -122,66 +126,124 @@ impl ChatService {
}; };
self.publish(channel_id, message.clone()).await; self.publish(channel_id, message.clone()).await;
self.messages.create_new(message.clone(), channel_id).await?;
let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?;
// TODO: caching w redis at repository layer // TODO: caching w redis at repository layer
let svc_instance = self.clone(); if !message.text.starts_with("/ask ") {
return Ok(());
let Some(text) = text.strip_prefix("/ask ") else {
return Ok(())
};
if !svc_instance.llm.enabled() {
return Ok(())
} }
let svc_instance = self.clone();
if !svc_instance.llm.enabled() {
return Ok(());
}
let context = self.get_history(channel_id, 25).await?;
tokio::spawn(async move { tokio::spawn(async move {
let sender = match svc_instance.channels.lock().await.get(&channel_id) {
Some(s) => s.get_sender(),
None => return,
};
let response = svc_instance.llm let response = svc_instance.llm
.query(&message) .query(&message, &context, sender)
.await; .await;
if let Ok(reply) = response { let Ok(reply) = response else {
tracing::info!("LLM reply: {}", reply.text);
svc_instance.publish(channel_id, reply.clone()).await;
// TODO: cache response (or do with redis!)
if let Err(e) = svc_instance.messages
.create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await {
tracing::error!("Failed to persist LLM reply: {}", e);
}
tracing::info!("LLM reply persisted");
} else {
tracing::warn!("Error contacting LLM: {:?}", response); tracing::warn!("Error contacting LLM: {:?}", response);
return;
};
// TODO: cache response (or do with redis!)
if let Err(e) = svc_instance.messages.create_new(reply, channel_id).await {
tracing::error!("Failed to persist LLM reply: {}", e);
} }
tracing::debug!("Full LLM reply persisted");
}); });
Ok(()) Ok(())
} }
async fn get_history(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
let mut map = self.channels.lock().await;
if let Some(channel) = map.get(&channel_id) && !channel.cache.is_empty() {
Ok(channel.history().clone().into_iter().take(limit).collect())
} else {
let messages: Vec<_> = self.messages.get_latest_by_channel_desc(channel_id, limit, 0).await?.into_iter().rev().collect();
map.insert(channel_id,
ChannelState::new(self.buffer_size, Some(messages.clone().into()))
);
Ok(messages)
}
}
/// Subscribe to the specified channel. /// Subscribe to the specified channel.
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatMsg> { pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatEvent> {
let mut map = self.senders.lock().await; let mut map = self.channels.lock().await;
let sender = map let channel = map
.entry(channel_id) .entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0); .or_insert_with(|| ChannelState::new(self.buffer_size, None));
sender.subscribe() channel.subscribe()
} }
// Private helper methods // Private helper methods
/// Publish a message to the specified channel. /// Publish a message to the specified channel.
async fn publish(&self, channel_id: i64, msg: ChatMsg) { async fn publish(&self, channel_id: i64, msg: ChatMsg) {
let mut map = self.senders.lock().await; let mut map = self.channels.lock().await;
let sender = map let channel = map
.entry(channel_id) .entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0); .or_insert_with(|| ChannelState::new(self.buffer_size, None));
let _ = sender.send(msg); channel.send(msg);
}
}
#[derive(Clone)]
pub struct ChannelState {
sender: Sender<ChatEvent>,
cache: VecDeque<ChatMsg>,
last_updated: Instant,
}
impl ChannelState {
const MAX_HISTORY_SIZE: usize = 100;
#[must_use]
pub fn new(buffer_size: usize, history: Option<VecDeque<ChatMsg>>) -> Self {
Self {
sender: broadcast::channel(buffer_size).0,
cache: history.unwrap_or_default(),
last_updated: Instant::now(),
}
} }
pub fn history(&self) -> &VecDeque<ChatMsg> {
&self.cache
}
pub fn get_sender(&self) -> Sender<ChatEvent> {
self.sender.clone()
}
#[must_use]
pub fn send(&mut self, msg: ChatMsg) {
while self.cache.len() >= Self::MAX_HISTORY_SIZE {
self.cache.pop_front();
}
self.cache.push_back(msg.clone());
if self.sender.send(ChatEvent::SendMessage(msg)).is_err() {
tracing::warn!("Sent message to channel with no subscribers");
}
}
#[must_use]
pub fn subscribe(&self) -> broadcast::Receiver<ChatEvent> {
self.sender.subscribe()
}
}
}
+119 -41
View File
@@ -1,3 +1,13 @@
use std::env;
use std::sync::LazyLock;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast::Sender;
use uuid::Uuid;
use crate::model::event::{ChatEvent, ChatMsg};
use crate::error::{ApiResult, AppError};
#[derive(Clone)] #[derive(Clone)]
pub struct LlmService; pub struct LlmService;
@@ -15,71 +25,139 @@ impl LlmService {
LMSTUDIO_URL.is_some() LMSTUDIO_URL.is_some()
} }
pub async fn query(&self, message: &ChatMsg) -> ApiResult<ChatMsg> { pub async fn query(&self, message: &ChatMsg, context: &[ChatMsg], sender: Sender<ChatEvent>) -> ApiResult<ChatMsg> {
let Some(url) = LMSTUDIO_URL.clone() else { let Some(url) = LMSTUDIO_URL.clone() else {
return Err(AppError::internal("AI not enabled!")) return Err(AppError::internal("AI not enabled!"))
}; };
let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into()); let reply_id = Uuid::new_v4();
let timestamp = chrono::Utc::now();
let client = reqwest::Client::new(); let mut reply = ChatMsg {
id: reply_id,
// Build the request body display_name: Some("llm".into()),
let payload = LlmRequest { user_id: 0,
model, // whatever model you run locally text: String::new(),
messages: vec![Message { timestamp,
role: "user".into(),
content: message.text.clone(),
}],
}; };
let _ = sender.send(ChatEvent::SendMessage(reply.clone()));
// POST to lmstudio (default 127.0.0.1:1234) let mut messages: Vec<Message> = Vec::new();
let resp = client let system_prompt = format!(
"You are a helpful assistant in a group chat. \
You are talking to '{}'. \
Keep responses concise and conversational.",
message.display_name.as_deref().unwrap_or("unknown"),
);
messages.push(Message { role: "system".into(), content: system_prompt });
for msg in context {
let role = if msg.user_id == 0 {
"assistant" // your LLM user_id convention
} else {
"user"
};
messages.push(Message {
role: role.into(),
content: format!(
"{}: {}",
msg.display_name.as_deref().unwrap_or("unknown"),
msg.text
),
});
}
messages.push(Message {
role: "user".into(),
content: format!(
"{}: {}",
message.display_name.as_deref().unwrap_or("unknown"),
message.text.trim_start_matches("/ask ") // strip the command prefix
),
});
let Ok(resp) = reqwest::Client::new()
.post(url) .post(url)
.json(&payload) .json(&LlmRequest {
think: false,
model: LMSTUDIO_MODEL
.clone()
.unwrap_or_else(|| "gpt-oss-20b".into()),
messages,
stream: true,
})
.send() .send()
.await .await
.map_err(|_| AppError::internal("Failed to make request to LLM server"))?; else {
tracing::warn!("Failed to reach LLM");
let _ = sender.send(ChatEvent::MessageAppendContent {
id: reply_id,
content: String::from("I'm not available right now. Please try again later.")
});
return Err(AppError::internal("Failed to reach LLM"));
};
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)] let mut full_text = String::new();
struct LlmResponse { let mut buffer = String::new();
choices: Vec<Choice>, let mut stream = resp.bytes_stream();
}
#[derive(Deserialize)] while let Some(chunk) = stream.next().await {
struct Choice { let chunk = chunk.map_err(|_| AppError::internal("Stream error"))?;
message: Message, buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(pos) = buffer.find('\n') {
let line = buffer[..pos].trim().to_string();
buffer = buffer[pos + 1..].to_string();
if line == "data: [DONE]" {
break;
}
if let Some(json) = line.strip_prefix("data: ")
&& let Ok(parsed) = serde_json::from_str::<StreamingResponse>(json)
&& let Some(content) = parsed.choices[0].delta.content.as_ref() {
full_text.push_str(content);
let _ = sender.send(ChatEvent::MessageAppendContent {
id: reply_id,
content: content.clone(),
});
}
}
} }
let llm_resp: LlmResponse = resp reply.text = full_text;
.json()
.await
.map_err(|_| AppError::internal("Failed to parse LLM response"))?;
Ok(ChatMsg { Ok(reply)
display_name: Some(String::from("llm")),
user_id: 0,
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now(),
})
} }
} }
use std::env;
use std::sync::LazyLock;
// src/llm.rs
use serde::{Deserialize, Serialize};
use crate::api::chat::ChatMsg;
use crate::error::{ApiResult, AppError};
use crate::svc::chat_svc::ChatService;
#[derive(Serialize)] #[derive(Serialize)]
struct LlmRequest { struct LlmRequest {
model: String, model: String,
messages: Vec<Message>, messages: Vec<Message>,
stream: bool,
think: bool,
} }
#[derive(Deserialize)]
struct StreamingResponse {
choices: Vec<StreamingChoice>,
}
#[derive(Deserialize)]
struct StreamingChoice {
delta: Delta,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct Delta {
#[serde(default)]
content: Option<String>,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct Message { struct Message {
role: String, // "user" or "assistant" role: String, // "user" or "assistant"
+2 -1
View File
@@ -3,4 +3,5 @@ pub mod chat_svc;
pub mod settings_svc; pub mod settings_svc;
pub mod user_svc; pub mod user_svc;
pub mod access_token_svc; pub mod access_token_svc;
pub mod llm_service; pub mod llm_service;
pub mod relationship_svc;
+6
View File
@@ -0,0 +1,6 @@
pub struct RelationshipService {}
impl RelationshipService {
pub fn new() -> Self { Self {} }
}
+228
View File
@@ -0,0 +1,228 @@
import argparse
import json
import threading
import time
from dataclasses import dataclass
from getpass import getpass
from typing import List
import requests
BASE_URL = "http://localhost:8000"
@dataclass
class AuthResult:
token: str
def signup(session: requests.Session, email: str, username: str, password: str, access_token: str) -> None:
url = f"{BASE_URL}/api/signup"
payload = {
"email": email,
"username": username,
"password": password,
"access_token": access_token,
}
resp = session.post(url, json=payload, timeout=10)
resp.raise_for_status()
def login(session: requests.Session, username: str, password: str) -> AuthResult:
url = f"{BASE_URL}/api/login"
payload = {
"username": username,
"password": password,
}
resp = session.post(url, json=payload, timeout=10)
resp.raise_for_status()
data = resp.json()
token = data.get("token")
if not token:
raise RuntimeError(f"Login response did not contain token: {data}")
return AuthResult(token=token)
def post_message(session: requests.Session, channel_id: int, token: str, text: str, display_name: str, user_id: int) -> None:
url = f"{BASE_URL}/api/chat/{channel_id}"
payload = {
"display_name": display_name,
"user_id": user_id,
"text": text,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
headers = {"Authorization": f"Bearer {token}"}
resp = session.post(url, json=payload, headers=headers, timeout=10)
resp.raise_for_status()
def read_sse_messages(
session: requests.Session,
channel_id: int,
token: str,
expected_count: int,
timeout_s: int,
capture_live_messages: threading.Event,
) -> List[dict]:
url = f"{BASE_URL}/api/events/{channel_id}"
headers = {
"Authorization": f"Bearer {token}",
"Accept": "text/event-stream",
}
received: List[dict] = []
deadline = time.monotonic() + timeout_s
try:
with session.get(url, headers=headers, stream=True, timeout=(5, timeout_s)) as resp:
resp.raise_for_status()
event_data_lines: List[str] = []
for raw_line in resp.iter_lines(decode_unicode=True):
if time.monotonic() > deadline:
break
if raw_line is None:
continue
line = raw_line.strip()
if not line:
if event_data_lines:
joined = "\n".join(event_data_lines)
event_data_lines.clear()
try:
obj = json.loads(joined)
except json.JSONDecodeError:
continue
if capture_live_messages.is_set():
received.append(obj)
if expected_count > 0 and len(received) >= expected_count:
break
else:
print(f"Discarding message: {obj}")
continue
if line.startswith("data:"):
event_data_lines.append(line[len("data:"):].strip())
except requests.exceptions.Timeout:
print("Timeout while reading SSE.")
except requests.exceptions.RequestException as exc:
print(f"Error reading SSE: {exc}")
return received
def prompt_nonempty(label: str, secret: bool = False) -> str:
while True:
value = getpass(label) if secret else input(label)
value = value.strip()
if value:
return value
print("Please enter a value.")
def main() -> int:
parser = argparse.ArgumentParser(description="Chat integration test against localhost:8000")
parser.add_argument("--existing-account", action="store_true",
help="Skip signup and only log in with an existing account")
parser.add_argument("--email", default=None)
parser.add_argument("--username", default=None)
parser.add_argument("--password", default=None)
parser.add_argument("--access-token", default=None,
help="Required only for signup mode")
parser.add_argument("--channel-id", type=int, default=1)
parser.add_argument("--message-count", type=int, default=5)
parser.add_argument("--timeout", type=int, default=15)
args = parser.parse_args()
session = requests.Session()
if args.existing_account:
username = args.username or prompt_nonempty("Username: ")
password = args.password or prompt_nonempty("Password: ", secret=True)
else:
email = args.email or prompt_nonempty("Email: ")
username = args.username or prompt_nonempty("Username: ")
password = args.password or prompt_nonempty("Password: ", secret=True)
access_token = args.access_token or prompt_nonempty("Access token: ")
print("[1/5] Signing up...")
try:
signup(session, email, username, password, access_token)
print(" signup ok")
except requests.HTTPError as e:
print(f" signup returned HTTP error: {e}")
print(" continuing to login...")
print("[2/5] Logging in...")
auth = login(session, username, password)
print(" login ok")
print(f" token: {auth.token[:12]}...")
print("[3/5] Opening event stream...")
received_messages: List[dict] = []
capture_live_messages = threading.Event()
stream_done = threading.Event()
def stream_reader() -> None:
nonlocal received_messages
try:
received_messages = read_sse_messages(
session=session,
channel_id=args.channel_id,
token=auth.token,
expected_count=args.message_count,
timeout_s=args.timeout,
capture_live_messages=capture_live_messages,
)
finally:
stream_done.set()
t = threading.Thread(target=stream_reader, daemon=True)
t.start()
# Give the server time to flush backlog on this same stream connection.
time.sleep(1.0)
print("[4/5] Starting to capture live messages and sending messages...")
capture_live_messages.set()
sent_texts = [f"Message {i}" for i in range(args.message_count)]
for i, text in enumerate(sent_texts):
post_message(
session=session,
channel_id=args.channel_id,
token=auth.token,
text=text,
display_name=username,
user_id=1,
)
print(f" sent {i + 1}/{args.message_count}: {text}")
time.sleep(0.1)
stream_done.wait(timeout=args.timeout)
t.join(timeout=1)
print("\nReceived messages:")
for i, msg in enumerate(received_messages, start=1):
print(f" {i}. {msg}")
received_texts = [m.get("text") for m in received_messages if isinstance(m, dict)]
for text in sent_texts:
if text not in received_texts:
print(f"\nFAIL: missing message: {text}")
return 1
print("\nPASS: login and message delivery test succeeded.")
return 0
if __name__ == "__main__":
raise SystemExit(main())