Compare commits

3 Commits

Author SHA1 Message Date
zxq5 d33eee1281 frontend v0.4.0 2026-04-06 01:02:39 +01:00
zxq5 7c9b733813 updated gitignore 2026-04-06 01:00:27 +01:00
zxq5 bda1ef251a full backend rewrite.
calling this v0.4.0
2026-04-06 00:57:23 +01:00
92 changed files with 3769 additions and 1649 deletions
+3
View File
@@ -1,6 +1,9 @@
*/target */target
.env .env
.log* .log*
Cargo.lock Cargo.lock
.cargo/ .cargo/
.sqlx/
docker-compose* docker-compose*
+4 -1
View File
@@ -2,7 +2,10 @@
<module type="JAVA_MODULE" version="4"> <module type="JAVA_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true"> <component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output /> <exclude-output />
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/backend/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/backend/target" />
</content>
<orderEntry type="inheritedJdk" /> <orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
+17
View File
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="chatapp dev" uuid="81992477-fd6f-427e-a27e-7378c26db6ef">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>
<jdbc-url>jdbc:postgresql://100.118.108.58:5432/chatapp_dev</jdbc-url>
<jdbc-additional-properties>
<property name="com.intellij.clouds.kubernetes.db.host.port" />
<property name="com.intellij.clouds.kubernetes.db.enabled" value="false" />
<property name="com.intellij.clouds.kubernetes.db.container.port" />
</jdbc-additional-properties>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>
+11
View File
@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>
<option name="externalProjectPath" value="$PROJECT_DIR$/android" />
</GradleProjectSettings>
</option>
</component>
</project>
+1
View File
@@ -1,4 +1,5 @@
<project version="4"> <project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="ProjectRootManager" version="2" project-jdk-name="25" project-jdk-type="JavaSDK"> <component name="ProjectRootManager" version="2" project-jdk-name="25" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" /> <output url="file://$PROJECT_DIR$/out" />
</component> </component>
+3
View File
@@ -1,7 +1,9 @@
*.iml *.iml
.gradle .gradle
/local.properties /local.properties
/keystore.properties
/.idea/caches /.idea/caches
/.idea/.cache
/.idea/libraries /.idea/libraries
/.idea/modules.xml /.idea/modules.xml
/.idea/workspace.xml /.idea/workspace.xml
@@ -13,3 +15,4 @@
.externalNativeBuild .externalNativeBuild
.cxx .cxx
local.properties local.properties
release/
+8
View File
@@ -4,6 +4,14 @@
<selectionStates> <selectionStates>
<SelectionState runConfigName="app"> <SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" /> <option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-04-02T14:33:39.814557661Z">
<Target type="DEFAULT_BOOT">
<handle>
<DeviceId pluginId="PhysicalDevice" identifier="serial=00319362N000094" />
</handle>
</Target>
</DropdownSelection>
<DialogSelection />
</SelectionState> </SelectionState>
<SelectionState runConfigName="MainActivity"> <SelectionState runConfigName="MainActivity">
<option name="selectionMode" value="DROPDOWN" /> <option name="selectionMode" value="DROPDOWN" />
+34 -3
View File
@@ -1,3 +1,5 @@
import java.util.Properties
plugins { plugins {
alias(libs.plugins.android.application) alias(libs.plugins.android.application)
alias(libs.plugins.kotlin.compose) alias(libs.plugins.kotlin.compose)
@@ -8,6 +10,25 @@ android {
namespace = "dev.zxq5.chatapp.android" namespace = "dev.zxq5.chatapp.android"
compileSdk = 35 compileSdk = 35
val keystorePropertiesFile = rootProject.file("local.properties")
val keystoreProperties = Properties()
if (keystorePropertiesFile.exists()) {
keystoreProperties.load(keystorePropertiesFile.inputStream())
}
signingConfigs {
create("release") {
storeFile = file("${System.getProperty("user.home")}/keystores/chatapp.jks")
storePassword = keystoreProperties["KEYSTORE_PASSWORD"] as String?
?: System.getenv("KEYSTORE_PASSWORD")
?: ""
keyAlias = "chatapp"
keyPassword = keystoreProperties["KEY_PASSWORD"] as String?
?: System.getenv("KEY_PASSWORD")
?: ""
}
}
defaultConfig { defaultConfig {
applicationId = "dev.zxq5.chatapp.android" applicationId = "dev.zxq5.chatapp.android"
minSdk = 26 minSdk = 26
@@ -20,19 +41,30 @@ android {
buildTypes { buildTypes {
release { release {
isMinifyEnabled = false isMinifyEnabled = true // shrinks code
isShrinkResources = true // removes unused resources
signingConfig = signingConfigs.getByName("release")
proguardFiles( proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"), getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro" "proguard-rules.pro"
) )
buildConfigField("String", "BASE_URL", "\"https://chat.zxq5.dev\"")
}
debug {
isMinifyEnabled = false
isDebuggable = true
applicationIdSuffix = ".debug" // lets you install both side by side
buildConfigField("String", "BASE_URL", "\"http://zxq5-x1:8000\"")
} }
} }
compileOptions { compileOptions {
sourceCompatibility = JavaVersion.VERSION_11 sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11 targetCompatibility = JavaVersion.VERSION_11
} }
buildFeatures { buildFeatures {
compose = true compose = true
buildConfig = true
} }
} }
@@ -44,7 +76,6 @@ dependencies {
implementation(libs.ktor.client.auth) // Auth plugin implementation(libs.ktor.client.auth) // Auth plugin
// Kotlinx Serialization // Kotlinx Serialization
implementation(libs.kotlinx.serialization.json) implementation(libs.kotlinx.serialization.json)
// Coroutines // Coroutines
implementation(libs.kotlinx.coroutines.android) implementation(libs.kotlinx.coroutines.android)
@@ -73,4 +104,4 @@ dependencies {
androidTestImplementation(libs.androidx.compose.ui.test.junit4) androidTestImplementation(libs.androidx.compose.ui.test.junit4)
debugImplementation(libs.androidx.compose.ui.tooling) debugImplementation(libs.androidx.compose.ui.tooling)
debugImplementation(libs.androidx.compose.ui.test.manifest) debugImplementation(libs.androidx.compose.ui.test.manifest)
} }
+27 -1
View File
@@ -18,4 +18,30 @@
# If you keep the line number information, uncomment this to # If you keep the line number information, uncomment this to
# hide the original source file name. # hide the original source file name.
#-renamesourcefileattribute SourceFile #-renamesourcefileattribute SourceFile
# Ktor
-keep class io.ktor.** { *; }
-keep class kotlinx.coroutines.** { *; }
# Kotlinx serialization
-keepattributes *Annotation*, InnerClasses
-dontnote kotlinx.serialization.AnnotationsKt
-keep,includedescriptorclasses class dev.zxq5.chatapp.android.**$$serializer { *; }
-keepclassmembers class dev.zxq5.chatapp.android.** {
*** Companion;
}
-keepclasseswithmembers class dev.zxq5.chatapp.android.** {
kotlinx.serialization.KSerializer serializer(...);
}
# Keep model classes (serialization needs these)
-keep class dev.zxq5.chatapp.android.api.model.** { *; }
-keep class dev.zxq5.chatapp.android.data.model.** { *; }
# Fix for missing errorprone and javax annotations used by Tink and other libraries
-dontwarn com.google.errorprone.annotations.**
-dontwarn javax.annotation.**
# Fix for missing java.lang.management referenced by Ktor (not available on Android)
-dontwarn java.lang.management.**
+10
View File
@@ -3,6 +3,10 @@
xmlns:tools="http://schemas.android.com/tools"> xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" /> <uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
<application <application
android:name=".ChatApplication" android:name=".ChatApplication"
@@ -15,6 +19,12 @@
android:supportsRtl="true" android:supportsRtl="true"
android:theme="@style/Theme.Chatapp" android:theme="@style/Theme.Chatapp"
android:usesCleartextTraffic="true"> android:usesCleartextTraffic="true">
<service
android:name=".core.service.MessageStreamService"
android:foregroundServiceType="dataSync"
android:exported="false"/>
<activity <activity
android:name=".MainActivity" android:name=".MainActivity"
android:exported="true" android:exported="true"
@@ -1,6 +1,9 @@
package dev.zxq5.chatapp.android package dev.zxq5.chatapp.android
import android.app.Application import android.app.Application
import android.app.NotificationChannel
import android.app.NotificationManager
import android.os.Build
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.data.repository.AuthRepository import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.data.repository.ChatRepository
@@ -8,6 +11,10 @@ import dev.zxq5.chatapp.android.data.repository.SettingsRepository
class ChatApplication : Application() { class ChatApplication : Application() {
object AppState {
var isInForeground = false
}
val tokenStore by lazy { TokenStore(this) } val tokenStore by lazy { TokenStore(this) }
val authRepository by lazy { AuthRepository(tokenStore) } val authRepository by lazy { AuthRepository(tokenStore) }
val chatRepository by lazy { ChatRepository(tokenStore) } val chatRepository by lazy { ChatRepository(tokenStore) }
@@ -15,5 +22,30 @@ class ChatApplication : Application() {
override fun onCreate() { override fun onCreate() {
super.onCreate() super.onCreate()
createNotificationChannels()
}
private fun createNotificationChannels() {
val messageChannel = NotificationChannel(
"messages",
"Messages",
NotificationManager.IMPORTANCE_HIGH
).apply {
description = "New message notifications"
enableVibration(true)
}
// add this — required for the foreground service persistent notification
val serviceChannel = NotificationChannel(
"service",
"Background connection",
NotificationManager.IMPORTANCE_LOW
).apply {
description = "Keeps messages running in background"
}
val manager = getSystemService(NotificationManager::class.java)
manager.createNotificationChannel(messageChannel)
manager.createNotificationChannel(serviceChannel)
} }
} }
@@ -4,25 +4,42 @@ import android.os.Bundle
import androidx.activity.ComponentActivity import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge import androidx.activity.enableEdgeToEdge
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.ChatBubbleOutline
import androidx.compose.material.icons.outlined.PeopleOutline
import androidx.compose.material.icons.outlined.Settings
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.NavigationBar
import androidx.compose.material3.NavigationBarItem
import androidx.compose.material3.NavigationBarItemDefaults
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel import androidx.lifecycle.viewmodel.compose.viewModel
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.ChatApplication.AppState
import dev.zxq5.chatapp.android.data.repository.AuthRepository import dev.zxq5.chatapp.android.core.service.MessageStreamService
import dev.zxq5.chatapp.android.data.repository.AuthState import dev.zxq5.chatapp.android.data.repository.AuthState
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.feature.auth.AuthScreen
import dev.zxq5.chatapp.android.data.repository.SettingsRepository
import dev.zxq5.chatapp.android.feature.auth.AuthViewModel import dev.zxq5.chatapp.android.feature.auth.AuthViewModel
import dev.zxq5.chatapp.android.feature.chat.ChatScreen
import dev.zxq5.chatapp.android.feature.chat.ChatViewModel import dev.zxq5.chatapp.android.feature.chat.ChatViewModel
import dev.zxq5.chatapp.android.feature.chat.Screen import dev.zxq5.chatapp.android.feature.chat.Screen
import dev.zxq5.chatapp.android.feature.settings.SettingsViewModel import dev.zxq5.chatapp.android.feature.contacts.ContactsScreen
import dev.zxq5.chatapp.android.feature.auth.AuthScreen
import dev.zxq5.chatapp.android.feature.chat.ChatScreen
import dev.zxq5.chatapp.android.feature.settings.SettingsScreen import dev.zxq5.chatapp.android.feature.settings.SettingsScreen
import dev.zxq5.chatapp.android.feature.settings.SettingsViewModel
import dev.zxq5.chatapp.android.ui.theme.ChatappTheme import dev.zxq5.chatapp.android.ui.theme.ChatappTheme
class MainActivity : ComponentActivity() { class MainActivity : ComponentActivity() {
@@ -30,7 +47,6 @@ class MainActivity : ComponentActivity() {
super.onCreate(savedInstanceState) super.onCreate(savedInstanceState)
val app = application as ChatApplication val app = application as ChatApplication
val tokenStore = app.tokenStore
val authRepository = app.authRepository val authRepository = app.authRepository
val chatRepository = app.chatRepository val chatRepository = app.chatRepository
val settingsRepository = app.settingsRepository val settingsRepository = app.settingsRepository
@@ -44,37 +60,126 @@ class MainActivity : ComponentActivity() {
val authState by authViewModel.authState.collectAsState() val authState by authViewModel.authState.collectAsState()
val currentScreen by chatViewModel.currentScreen.collectAsState() val currentScreen by chatViewModel.currentScreen.collectAsState()
val selectedChannelId by chatViewModel.channelId.collectAsState()
Scaffold(modifier = Modifier.fillMaxSize()) { innerPadding -> LaunchedEffect(authState) {
androidx.compose.foundation.layout.Box(modifier = Modifier.padding(innerPadding)) { when (authState) {
when (authState) { AuthState.Authenticated -> MessageStreamService.start(this@MainActivity)
AuthState.Authenticated -> { AuthState.Unauthenticated -> MessageStreamService.stop(this@MainActivity)
when (currentScreen) { AuthState.AwaitingTotp -> {}
Screen.CHAT -> ChatScreen( }
viewModel = chatViewModel, }
onNavigateToSettings = { chatViewModel.navigateTo(Screen.SETTINGS) },
onLogout = { LaunchedEffect(Unit) {
authViewModel.logout() intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let {
chatViewModel.clearChat() chatViewModel.switchChannel(it.toLong())
} }
) }
Screen.SETTINGS -> SettingsScreen(
viewModel = settingsViewModel, if (authState == AuthState.Authenticated) {
onBack = { chatViewModel.navigateTo(Screen.CHAT) }, Scaffold(
onLogout = { modifier = Modifier.fillMaxSize(),
authViewModel.logout() bottomBar = {
chatViewModel.clearChat() // Only show bottom bar if we are NOT inside a specific chat channel
} if (selectedChannelId == null) {
) BottomDock(
} currentScreen = currentScreen,
onNavigate = { chatViewModel.navigateTo(it) }
)
} }
AuthState.AwaitingTotp, AuthState.Unauthenticated -> { }
AuthScreen(viewModel = authViewModel) ) { innerPadding ->
Box(modifier = Modifier.padding(innerPadding)) {
when (currentScreen) {
Screen.CHAT -> ChatScreen(
viewModel = chatViewModel,
onNavigateToSettings = { chatViewModel.navigateTo(Screen.SETTINGS) },
onLogout = {
authViewModel.logout()
chatViewModel.clearChat()
}
)
Screen.CONTACTS -> ContactsScreen()
Screen.SETTINGS -> SettingsScreen(
viewModel = settingsViewModel,
onLogout = {
authViewModel.logout()
chatViewModel.clearChat()
}
)
} }
} }
} }
} else {
AuthScreen(viewModel = authViewModel)
} }
} }
} }
} }
override fun onResume() {
super.onResume()
AppState.isInForeground = true
}
override fun onPause() {
super.onPause()
AppState.isInForeground = false
}
override fun onNewIntent(intent: android.content.Intent) {
super.onNewIntent(intent)
intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let { channelId ->
MessageStreamService.instance?.activeChannelId = channelId.toLong()
}
}
}
@Composable
fun BottomDock(currentScreen: Screen, onNavigate: (Screen) -> Unit) {
NavigationBar(
containerColor = MaterialTheme.colorScheme.background,
tonalElevation = 0.dp,
modifier = Modifier
.height(80.dp)
.border(
0.5.dp,
MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f),
RoundedCornerShape(topStart = 0.dp, topEnd = 0.dp)
)
) {
NavigationBarItem(
selected = currentScreen == Screen.CHAT,
onClick = { onNavigate(Screen.CHAT) },
icon = { Icon(Icons.Outlined.ChatBubbleOutline, contentDescription = "Chat") },
label = { Text("chat", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.CONTACTS,
onClick = { onNavigate(Screen.CONTACTS) },
icon = { Icon(Icons.Outlined.PeopleOutline, contentDescription = "Contacts") },
label = { Text("contacts", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.SETTINGS,
onClick = { onNavigate(Screen.SETTINGS) },
icon = { Icon(Icons.Outlined.Settings, contentDescription = "Settings") },
label = { Text("settings", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
}
} }
@@ -1,10 +1,11 @@
package dev.zxq5.chatapp.android.api package dev.zxq5.chatapp.android.api
import android.util.Log import android.util.Log
import dev.zxq5.chatapp.android.BuildConfig
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.LoginRequest import dev.zxq5.chatapp.android.api.model.LoginRequest
import dev.zxq5.chatapp.android.api.model.LoginResponse import dev.zxq5.chatapp.android.api.model.LoginResponse
import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode
import dev.zxq5.chatapp.android.core.BASE_URL
import dev.zxq5.chatapp.android.core.error.ApiResult import dev.zxq5.chatapp.android.core.error.ApiResult
import dev.zxq5.chatapp.android.api.model.SignupRequest import dev.zxq5.chatapp.android.api.model.SignupRequest
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
@@ -1,14 +1,17 @@
package dev.zxq5.chatapp.android.api package dev.zxq5.chatapp.android.api
import dev.zxq5.chatapp.android.core.BASE_URL import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SendMessage import dev.zxq5.chatapp.android.api.model.SendMessage
import dev.zxq5.chatapp.android.api.model.SpaceDto
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.android.Android import io.ktor.client.engine.android.Android
import io.ktor.client.plugins.auth.Auth import io.ktor.client.plugins.auth.Auth
import io.ktor.client.plugins.auth.providers.BearerTokens import io.ktor.client.plugins.auth.providers.BearerTokens
import io.ktor.client.plugins.auth.providers.bearer import io.ktor.client.plugins.auth.providers.bearer
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.get
import io.ktor.client.request.post import io.ktor.client.request.post
import io.ktor.client.request.prepareGet import io.ktor.client.request.prepareGet
import io.ktor.client.request.setBody import io.ktor.client.request.setBody
@@ -16,12 +19,15 @@ import io.ktor.client.statement.bodyAsChannel
import io.ktor.http.ContentType import io.ktor.http.ContentType
import io.ktor.http.contentType import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json import io.ktor.serialization.kotlinx.json.json
import io.ktor.utils.io.readUTF8Line import io.ktor.utils.io.readLine
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flow
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import kotlin.time.Clock
import kotlin.time.ExperimentalTime
class ChatClient(private val token: String) { class ChatClient(private val token: String) {
private val http = HttpClient(Android) { private val http = HttpClient(Android) {
install(ContentNegotiation) { install(ContentNegotiation) {
json(Json { ignoreUnknownKeys = true }) json(Json { ignoreUnknownKeys = true })
@@ -33,18 +39,21 @@ class ChatClient(private val token: String) {
} }
} }
suspend fun sendMessage(channelId: Int, userId: Int, text: String) { suspend fun getAccessibleChannels(): List<SpaceDto> = http.get("${BASE_URL}/api/accessible_channels").body()
@OptIn(ExperimentalTime::class)
suspend fun sendMessage(channelId: Long, userId: Int, text: String) {
http.post("${BASE_URL}/api/chat/$channelId") { http.post("${BASE_URL}/api/chat/$channelId") {
contentType(ContentType.Application.Json) contentType(ContentType.Application.Json)
setBody(SendMessage(user_id = userId, text = text, timestamp = System.currentTimeMillis())) setBody(SendMessage(user_id = userId, text = text, timestamp = Clock.System.now()))
} }
} }
fun messageStream(channelId: Int): Flow<Message> = flow { fun messageStream(channelId: Long): Flow<Message> = flow {
http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response -> http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response ->
val channel = response.bodyAsChannel() val channel = response.bodyAsChannel()
while (!channel.isClosedForRead) { while (!channel.isClosedForRead) {
val line = channel.readUTF8Line(256) ?: break val line = channel.readLine() ?: break
if (line.startsWith("data:")) { if (line.startsWith("data:")) {
val json = line.removePrefix("data:").trim() val json = line.removePrefix("data:").trim()
runCatching { Json.decodeFromString<Message>(json) } runCatching { Json.decodeFromString<Message>(json) }
@@ -54,4 +63,3 @@ class ChatClient(private val token: String) {
} }
} }
} }
@@ -1,6 +1,7 @@
package dev.zxq5.chatapp.android.api package dev.zxq5.chatapp.android.api
import android.util.Log import android.util.Log
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest
import dev.zxq5.chatapp.android.api.model.DisplayNameRequest import dev.zxq5.chatapp.android.api.model.DisplayNameRequest
import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest
@@ -10,7 +11,6 @@ import dev.zxq5.chatapp.android.api.model.TotpStatus
import dev.zxq5.chatapp.android.api.model.UsernameRequest import dev.zxq5.chatapp.android.api.model.UsernameRequest
import dev.zxq5.chatapp.android.api.model.TotpDeleteRequest import dev.zxq5.chatapp.android.api.model.TotpDeleteRequest
import dev.zxq5.chatapp.android.api.model.PasswordRequest import dev.zxq5.chatapp.android.api.model.PasswordRequest
import dev.zxq5.chatapp.android.core.BASE_URL
import dev.zxq5.chatapp.android.core.error.ApiResult import dev.zxq5.chatapp.android.core.error.ApiResult
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
import io.ktor.client.call.body import io.ktor.client.call.body
@@ -0,0 +1,15 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Channel @OptIn(ExperimentalTime::class) constructor(
val id: Long,
val name: String,
val description: String? = null,
val space_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@@ -1,11 +1,13 @@
package dev.zxq5.chatapp.android.api.model package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable @Serializable
data class Message( data class Message @OptIn(ExperimentalTime::class) constructor(
val user_id: Int, val user_id: Int,
val display_name: String, val display_name: String,
val text: String, val text: String,
val timestamp: Long val timestamp: Instant
) )
@@ -1,10 +1,12 @@
package dev.zxq5.chatapp.android.api.model package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable @Serializable
data class SendMessage( data class SendMessage @OptIn(ExperimentalTime::class) constructor(
val user_id: Int, val user_id: Int,
val text: String, val text: String,
val timestamp: Long val timestamp: Instant
) )
@@ -0,0 +1,28 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Space @OptIn(ExperimentalTime::class) constructor(
val id: Long,
val name: String,
val description: String? = null,
val owner_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@Serializable
data class SpaceDto @OptIn(ExperimentalTime::class) constructor(
val channels: List<Channel>,
val id: Long,
val name: String,
val description: String? = null,
val owner_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@@ -1,3 +1,3 @@
package dev.zxq5.chatapp.android.core package dev.zxq5.chatapp.android.core
const val BASE_URL = "http://zxq5-x1:8000" //const val BASE_URL = "http://zxq5-x1:8000"
@@ -0,0 +1,104 @@
package dev.zxq5.chatapp.android.core.service
import android.app.Service
import android.content.Context
import android.content.Intent
import android.os.Build
import android.os.IBinder
import android.util.Log
import dev.zxq5.chatapp.android.ChatApplication
import dev.zxq5.chatapp.android.data.repository.ChatRepository
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
// core/service/MessageStreamService.kt
class MessageStreamService : Service() {
private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private lateinit var notificationService: NotificationService
private lateinit var chatRepository: ChatRepository
// which channel the user is currently looking at
// set by the ViewModel when the user opens/closes a channel
var activeChannelId: Long? = null
set(value) {
field = value
Log.d("Service", "activeChannelId set to $value")
if (value != null) {
// restart stream with new channel
currentStreamJob?.cancel()
observeMessages()
}
}
private var currentStreamJob: kotlinx.coroutines.Job? = null
companion object {
var instance: MessageStreamService? = null
fun start(context: Context) {
val intent = Intent(context, MessageStreamService::class.java)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
context.startForegroundService(intent)
} else {
context.startService(intent)
}
}
fun stop(context: Context) {
context.stopService(Intent(context, MessageStreamService::class.java))
}
}
override fun onCreate() {
super.onCreate()
instance = this
notificationService = NotificationService(this)
chatRepository = (application as ChatApplication).chatRepository
}
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
startForeground(
NotificationService.FOREGROUND_NOTIFICATION_ID,
notificationService.buildForegroundNotification()
)
observeMessages()
return START_STICKY // restart if killed
}
private fun observeMessages() {
val channelId = activeChannelId ?: chatRepository.getLastActiveChannel()
Log.d("Service", "observeMessages called, channelId=$channelId")
if (channelId == null) {
Log.d("Service", "No channel to observe, waiting for switchChannel")
return
}
Log.d("Service", "Starting stream for channel $channelId")
currentStreamJob = serviceScope.launch {
chatRepository.messageStream(channelId)
.catch { e -> Log.e("Service", "Stream error", e) }
.collect { message ->
if (!ChatApplication.AppState.isInForeground) { // no channel focused, always notify
notificationService.showMessageNotification(
conversationId = activeChannelId.toString(),
senderName = message.display_name,
messagePreview = message.text.take(80)
)
}
}
}
}
override fun onBind(intent: Intent?): IBinder? = null
override fun onDestroy() {
super.onDestroy()
instance = null
serviceScope.cancel()
}
}
@@ -0,0 +1,94 @@
package dev.zxq5.chatapp.android.core.service
import android.app.Notification
import android.app.NotificationChannel
import android.app.NotificationManager
import android.app.PendingIntent
import android.content.Context
import android.content.Intent
import android.os.Build
import androidx.core.app.NotificationCompat
import dev.zxq5.chatapp.android.MainActivity
import dev.zxq5.chatapp.android.R
class NotificationService(private val context: Context) {
companion object {
const val CHANNEL_ID = "messages"
const val FOREGROUND_NOTIFICATION_ID = 1 // ← this needs to exist
}
private val manager = context.getSystemService(NotificationManager::class.java)
fun createChannels() {
// channel for new message notifications
val messageChannel = NotificationChannel(
CHANNEL_ID,
"Messages",
NotificationManager.IMPORTANCE_HIGH
).apply {
enableVibration(true)
}
// channel for the persistent foreground service notification
// low importance so it doesn't make noise
val serviceChannel = NotificationChannel(
"service",
"Background connection",
NotificationManager.IMPORTANCE_LOW
)
val mgr = context.getSystemService(NotificationManager::class.java)
mgr.createNotificationChannel(messageChannel)
mgr.createNotificationChannel(serviceChannel)
}
fun buildForegroundNotification(): Notification {
return NotificationCompat.Builder(context, "service")
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle("chatapp")
.setContentText("Connected")
.setOngoing(true)
.setSilent(true)
.build()
}
fun showMessageNotification(
conversationId: String,
senderName: String,
messagePreview: String, // for E2E this would be "New message" — no plaintext
notificationId: Int = conversationId.hashCode()
) {
// intent that opens the app to the right conversation when tapped
val intent = Intent(context, MainActivity::class.java).apply {
flags = Intent.FLAG_ACTIVITY_SINGLE_TOP
putExtra("conversation_id", conversationId)
}
val pendingIntent = PendingIntent.getActivity(
context,
notificationId,
intent,
PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE
)
val notification = NotificationCompat.Builder(context, "messages")
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle(senderName)
.setContentText(messagePreview)
.setPriority(NotificationCompat.PRIORITY_HIGH)
.setContentIntent(pendingIntent)
.setAutoCancel(true) // dismiss on tap
.build()
manager.notify(notificationId, notification)
}
fun dismissNotification(conversationId: String) {
manager.cancel(conversationId.hashCode())
}
fun dismissAll() {
manager.cancelAll()
}
}
@@ -3,6 +3,7 @@ package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.ChatClient import dev.zxq5.chatapp.android.api.ChatClient
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SpaceDto
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.emptyFlow
@@ -11,6 +12,8 @@ class ChatRepository(private val tokenStore: TokenStore) {
private var _chatClient: ChatClient? = null private var _chatClient: ChatClient? = null
private var _lastToken: String? = null private var _lastToken: String? = null
private var _lastActiveChannel: Long? = null
private fun getChatClient(): ChatClient? { private fun getChatClient(): ChatClient? {
val token = tokenStore.get() ?: return null val token = tokenStore.get() ?: return null
if (_chatClient == null || token != _lastToken) { if (_chatClient == null || token != _lastToken) {
@@ -25,14 +28,23 @@ class ChatRepository(private val tokenStore: TokenStore) {
_lastToken = null _lastToken = null
} }
fun getLastActiveChannel(): Long? {
return _lastActiveChannel
}
fun getUserId() = tokenStore.getUserId() fun getUserId() = tokenStore.getUserId()
suspend fun sendMessage(channelId: Int, text: String) { suspend fun getAccessibleChannels(): List<SpaceDto> {
return getChatClient()?.getAccessibleChannels() ?: emptyList()
}
suspend fun sendMessage(channelId: Long, text: String) {
val userId = tokenStore.getUserId() ?: return val userId = tokenStore.getUserId() ?: return
getChatClient()?.sendMessage(channelId, userId, text) getChatClient()?.sendMessage(channelId, userId, text)
} }
fun messageStream(channelId: Int): Flow<Message> { fun messageStream(channelId: Long): Flow<Message> {
_lastActiveChannel = channelId
return getChatClient()?.messageStream(channelId) ?: emptyFlow() return getChatClient()?.messageStream(channelId) ?: emptyFlow()
} }
} }
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.feature.auth
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.core.service.MessageStreamService
import dev.zxq5.chatapp.android.data.repository.AuthRepository import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.data.repository.LoginResult import dev.zxq5.chatapp.android.data.repository.LoginResult
import dev.zxq5.chatapp.android.data.repository.SignupResult import dev.zxq5.chatapp.android.data.repository.SignupResult
@@ -3,8 +3,12 @@ package dev.zxq5.chatapp.android.feature.chat
import android.util.Log import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.Channel
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.data.repository.ChatRepository
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.Space
import dev.zxq5.chatapp.android.api.model.SpaceDto
import dev.zxq5.chatapp.android.core.service.MessageStreamService
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
@@ -12,15 +16,13 @@ import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() { class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
private val _messages = MutableStateFlow<List<Message>>(emptyList()) private val _messages = MutableStateFlow<List<Message>>(emptyList())
val messages: StateFlow<List<Message>> = _messages val messages: StateFlow<List<Message>> = _messages
private val _channelId = MutableStateFlow<Int?>(null) private val _channelId = MutableStateFlow<Long?>(null)
val channelId: StateFlow<Int?> = _channelId val channelId: StateFlow<Long?> = _channelId
private val _currentScreen = MutableStateFlow(Screen.CHAT) private val _currentScreen = MutableStateFlow(Screen.CHAT)
val currentScreen: StateFlow<Screen> = _currentScreen val currentScreen: StateFlow<Screen> = _currentScreen
@@ -28,11 +30,35 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
private val _currentUserId = MutableStateFlow<Int?>(null) private val _currentUserId = MutableStateFlow<Int?>(null)
val currentUserId: StateFlow<Int?> = _currentUserId val currentUserId: StateFlow<Int?> = _currentUserId
private val _spaces = MutableStateFlow<List<SpaceDto>>(emptyList())
val spaces: StateFlow<List<SpaceDto>> = _spaces
private val _error = MutableStateFlow<String?>(null)
val error: StateFlow<String?> = _error
private val _channelError = MutableStateFlow<String?>(null)
val channelError: StateFlow<String?> = _channelError
private var streamJob: Job? = null private var streamJob: Job? = null
init { init {
_currentUserId.value = chatRepository.getUserId() _currentUserId.value = chatRepository.getUserId()
observeChannel() observeChannel()
loadAccessibleChannels()
}
fun loadAccessibleChannels() {
_error.value = null
viewModelScope.launch {
runCatching {
chatRepository.getAccessibleChannels()
}.onSuccess { data ->
_spaces.value = data
}.onFailure { e ->
Log.e("Chat", "Failed to load spaces", e)
_error.value = "Failed to load channels: ${e.message}"
}
}
} }
private fun observeChannel() { private fun observeChannel() {
@@ -40,11 +66,13 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_channelId.collect { id -> _channelId.collect { id ->
streamJob?.cancel() streamJob?.cancel()
_messages.value = emptyList() _messages.value = emptyList()
_channelError.value = null
if (id != null) { if (id != null) {
streamJob = launch { streamJob = launch {
chatRepository.messageStream(id) chatRepository.messageStream(id)
.catch { e -> .catch { e ->
Log.e("Chat", "Stream error", e) Log.e("Chat", "Stream error", e)
_channelError.value = "Connection lost: ${e.message}"
} }
.collect { message -> .collect { message ->
_messages.update { it + message } _messages.update { it + message }
@@ -59,12 +87,14 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_currentScreen.value = screen _currentScreen.value = screen
} }
fun switchChannel(id: Int?) { fun switchChannel(id: Long?) {
_channelId.value = id _channelId.value = id
MessageStreamService.instance?.activeChannelId = id
if (id != null) { if (id != null) {
// Refresh user ID just in case it wasn't available at init // Refresh user ID just in case it wasn't available at init
_currentUserId.value = chatRepository.getUserId() _currentUserId.value = chatRepository.getUserId()
chatRepository.resetClient()
} }
} }
@@ -78,6 +108,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
) )
}.onFailure { e -> }.onFailure { e ->
Log.e("Chat", "Send message error", e) Log.e("Chat", "Send message error", e)
_channelError.value = "Failed to send message"
} }
} }
} }
@@ -86,8 +117,14 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_messages.value = emptyList() _messages.value = emptyList()
_channelId.value = null _channelId.value = null
_currentUserId.value = null _currentUserId.value = null
_error.value = null
_channelError.value = null
streamJob?.cancel() streamJob?.cancel()
chatRepository.resetClient() chatRepository.resetClient()
MessageStreamService.instance?.activeChannelId = null
}
fun clearChannelError() {
_channelError.value = null
} }
} }
@@ -1,5 +1,5 @@
package dev.zxq5.chatapp.android.feature.chat package dev.zxq5.chatapp.android.feature.chat
enum class Screen { enum class Screen {
CHAT, SETTINGS CHAT, CONTACTS, SETTINGS
} }
@@ -28,22 +28,21 @@ import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.filled.ArrowBack import androidx.compose.material.icons.automirrored.filled.ArrowBack
import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.filled.Refresh
import androidx.compose.material.icons.filled.Send
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.NavigationBar
import androidx.compose.material3.NavigationBarItem
import androidx.compose.material3.NavigationBarItemDefaults
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.SnackbarHost
import androidx.compose.material3.SnackbarHostState
import androidx.compose.material3.Surface import androidx.compose.material3.Surface
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBar import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material.icons.filled.Send
import androidx.compose.material.icons.outlined.ChatBubbleOutline
import androidx.compose.material.icons.outlined.Settings
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
@@ -56,26 +55,29 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.input.ImeAction import androidx.compose.ui.text.input.ImeAction
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.Dp import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.api.model.Channel
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import java.text.DateFormat import java.text.DateFormat
import java.util.Date import java.util.Date
import kotlin.time.ExperimentalTime
@Composable @Composable
fun ChatScreen( fun ChatScreen(
viewModel: ChatViewModel, viewModel: ChatViewModel,
onNavigateToSettings: () -> Unit, onNavigateToSettings: () -> Unit,
onLogout: () -> Unit // Note: logout is now part of SettingsScreen in this UI, but we'll keep the param for now onLogout: () -> Unit
) { ) {
val selectedChannelId by viewModel.channelId.collectAsState() val selectedChannelId by viewModel.channelId.collectAsState()
if (selectedChannelId == null) { if (selectedChannelId == null) {
ChannelListScreen( ChannelListScreen(
viewModel = viewModel, viewModel = viewModel,
onChannelSelect = { viewModel.switchChannel(it) }, onChannelSelect = { viewModel.switchChannel(it) }
onNavigateToSettings = onNavigateToSettings
) )
} else { } else {
MessageScreen( MessageScreen(
@@ -90,20 +92,15 @@ fun ChatScreen(
@Composable @Composable
fun ChannelListScreen( fun ChannelListScreen(
viewModel: ChatViewModel, viewModel: ChatViewModel,
onChannelSelect: (Int) -> Unit, onChannelSelect: (Long) -> Unit
onNavigateToSettings: () -> Unit
) { ) {
val spaces by viewModel.spaces.collectAsState()
val error by viewModel.error.collectAsState()
Scaffold( Scaffold(
containerColor = MaterialTheme.colorScheme.background, containerColor = MaterialTheme.colorScheme.background,
topBar = { topBar = {
Column { Column {
Spacer(Modifier.height(8.dp))
Text(
"contacts",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
modifier = Modifier.padding(horizontal = 20.dp)
)
TopAppBar( TopAppBar(
title = { title = {
Text( Text(
@@ -115,103 +112,69 @@ fun ChannelListScreen(
colors = TopAppBarDefaults.topAppBarColors( colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent, containerColor = Color.Transparent,
titleContentColor = MaterialTheme.colorScheme.onSurface titleContentColor = MaterialTheme.colorScheme.onSurface
) ),
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
) )
Text( Text(
"5 channels · end-to-end encrypted", "Public channels - dms coming soon.",
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f), color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f),
modifier = Modifier.padding(horizontal = 20.dp, vertical = 2.dp) modifier = Modifier.padding(horizontal = 20.dp, vertical = 2.dp)
) )
Spacer(Modifier.height(12.dp)) Spacer(Modifier.height(12.dp))
}
Row( }
modifier = Modifier ) { padding ->
.fillMaxWidth() if (error != null) {
.background(MaterialTheme.colorScheme.secondary.copy(alpha = 0.2f)) Column(
.padding(horizontal = 20.dp, vertical = 8.dp), modifier = Modifier.fillMaxSize().padding(padding),
verticalAlignment = Alignment.CenterVertically horizontalAlignment = Alignment.CenterHorizontally,
) { verticalArrangement = Arrangement.Center
Box( ) {
modifier = Modifier Text(
.size(6.dp) text = error!!,
.clip(CircleShape) color = MaterialTheme.colorScheme.error,
.background(MaterialTheme.colorScheme.primary) textAlign = TextAlign.Center,
) modifier = Modifier.padding(16.dp)
Spacer(Modifier.width(10.dp)) )
Text( Button(onClick = { viewModel.loadAccessibleChannels() }) {
"global · walkie talkie", Icon(Icons.Default.Refresh, contentDescription = null)
style = MaterialTheme.typography.labelSmall, Spacer(Modifier.width(8.dp))
color = MaterialTheme.colorScheme.onSurfaceVariant, Text("Retry")
modifier = Modifier.weight(1f)
)
Surface(
color = Color.Transparent,
border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant),
shape = RoundedCornerShape(4.dp)
) {
Text(
"hold to talk",
modifier = Modifier.padding(horizontal = 8.dp, vertical = 3.dp),
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
} }
} }
}, } else {
bottomBar = { BottomDock(viewModel, onNavigateToSettings) } LazyColumn(modifier = Modifier.padding(padding).fillMaxSize()) {
) { padding -> spaces.forEach { spaceDto ->
LazyColumn(modifier = Modifier.padding(padding).fillMaxSize()) { item {
items(10) { i -> Text(
val id = i + 1 text = spaceDto.name.lowercase(),
ChannelItem(id = id, onClick = { onChannelSelect(id) }) modifier = Modifier.padding(horizontal = 20.dp, vertical = 8.dp),
HorizontalDivider( style = MaterialTheme.typography.labelMedium,
modifier = Modifier.padding(horizontal = 20.dp), color = MaterialTheme.colorScheme.primary,
thickness = 0.5.dp, fontWeight = FontWeight.Bold
color = MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f) )
) }
items(spaceDto.channels) { channel ->
ChannelItem(channel = channel, onClick = { onChannelSelect(channel.id) })
HorizontalDivider(
modifier = Modifier.padding(horizontal = 20.dp),
thickness = 0.5.dp,
color = MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f)
)
}
item {
Spacer(Modifier.height(16.dp))
}
}
} }
} }
} }
} }
@Composable @Composable
fun BottomDock(viewModel: ChatViewModel, onNavigateToSettings: () -> Unit) { fun ChannelItem(channel: Channel, onClick: () -> Unit) {
val currentScreen by viewModel.currentScreen.collectAsState()
NavigationBar(
containerColor = MaterialTheme.colorScheme.background,
tonalElevation = 0.dp,
modifier = Modifier.border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f), RoundedCornerShape(topStart = 0.dp, topEnd = 0.dp))
) {
NavigationBarItem(
selected = currentScreen == Screen.CHAT,
onClick = { viewModel.navigateTo(Screen.CHAT) },
icon = { Icon(Icons.Outlined.ChatBubbleOutline, contentDescription = "Chat") },
label = { Text("chat", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.SETTINGS,
onClick = onNavigateToSettings,
icon = { Icon(Icons.Outlined.Settings, contentDescription = "Settings") },
label = { Text("settings", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
}
}
@Composable
fun ChannelItem(id: Int, onClick: () -> Unit) {
Row( Row(
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
@@ -227,7 +190,7 @@ fun ChannelItem(id: Int, onClick: () -> Unit) {
contentAlignment = Alignment.Center contentAlignment = Alignment.Center
) { ) {
Text( Text(
"C$id", channel.name.take(1).uppercase(),
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
) )
@@ -235,31 +198,30 @@ fun ChannelItem(id: Int, onClick: () -> Unit) {
Spacer(Modifier.width(12.dp)) Spacer(Modifier.width(12.dp))
Column(modifier = Modifier.weight(1f)) { Column(modifier = Modifier.weight(1f)) {
Text( Text(
text = "channel $id", text = channel.name,
style = MaterialTheme.typography.bodyLarge, style = MaterialTheme.typography.bodyLarge,
color = MaterialTheme.colorScheme.onSurface color = MaterialTheme.colorScheme.onSurface
) )
Text( if (channel.description != null) {
text = "tap to join", Text(
style = MaterialTheme.typography.labelSmall, text = channel.description,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f) style = MaterialTheme.typography.labelSmall,
) color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f)
)
}
} }
Text(
"14:22",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f)
)
} }
} }
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit) { fun MessageScreen(channelId: Long, viewModel: ChatViewModel, onBack: () -> Unit) {
val messages by viewModel.messages.collectAsState() val messages by viewModel.messages.collectAsState()
val currentUserId by viewModel.currentUserId.collectAsState() val currentUserId by viewModel.currentUserId.collectAsState()
val channelError by viewModel.channelError.collectAsState()
var input by remember { mutableStateOf("") } var input by remember { mutableStateOf("") }
val listState = rememberLazyListState() val listState = rememberLazyListState()
val snackbarHostState = remember { SnackbarHostState() }
LaunchedEffect(messages.size) { LaunchedEffect(messages.size) {
if (messages.isNotEmpty()) { if (messages.isNotEmpty()) {
@@ -267,8 +229,16 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
} }
} }
LaunchedEffect(channelError) {
channelError?.let {
snackbarHostState.showSnackbar(it)
viewModel.clearChannelError()
}
}
Scaffold( Scaffold(
containerColor = MaterialTheme.colorScheme.background, containerColor = MaterialTheme.colorScheme.background,
snackbarHost = { SnackbarHost(snackbarHostState) },
topBar = { topBar = {
TopAppBar( TopAppBar(
title = { title = {
@@ -294,6 +264,7 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
) )
} }
}, },
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
colors = TopAppBarDefaults.topAppBarColors( colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent containerColor = Color.Transparent
) )
@@ -391,10 +362,13 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
} }
} }
@OptIn(ExperimentalTime::class)
@Composable @Composable
fun MessageBubble(message: Message, currentUserId: Int?) { fun MessageBubble(message: Message, currentUserId: Int?) {
val time = remember(message.timestamp) { val time = remember(message.timestamp) {
DateFormat.getTimeInstance(DateFormat.SHORT).format(Date(message.timestamp)).lowercase() DateFormat.getTimeInstance(DateFormat.SHORT)
.format(Date(message.timestamp.toEpochMilliseconds()))
.lowercase()
} }
val isMe = currentUserId != null && message.user_id == currentUserId val isMe = currentUserId != null && message.user_id == currentUserId
@@ -0,0 +1,52 @@
package dev.zxq5.chatapp.android.feature.contacts
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ContactsScreen() {
Scaffold(
containerColor = MaterialTheme.colorScheme.background,
topBar = {
TopAppBar(
title = {
Text(
"contacts",
style = MaterialTheme.typography.titleLarge,
color = MaterialTheme.colorScheme.onSurface
)
},
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent,
titleContentColor = MaterialTheme.colorScheme.onSurface
)
)
}
) { padding ->
Box(
modifier = Modifier
.fillMaxSize()
.padding(padding),
contentAlignment = Alignment.Center
) {
Text(
text = "Contacts coming soon",
style = MaterialTheme.typography.bodyLarge,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
@@ -63,7 +63,6 @@ import androidx.compose.ui.text.style.TextAlign
@Composable @Composable
fun SettingsScreen( fun SettingsScreen(
viewModel: SettingsViewModel, viewModel: SettingsViewModel,
onBack: () -> Unit,
onLogout: () -> Unit onLogout: () -> Unit
) { ) {
val is2faEnabled by viewModel.is2faEnabled.collectAsState() val is2faEnabled by viewModel.is2faEnabled.collectAsState()
@@ -88,15 +87,7 @@ fun SettingsScreen(
color = MaterialTheme.colorScheme.onSurface color = MaterialTheme.colorScheme.onSurface
) )
}, },
navigationIcon = { windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
IconButton(onClick = onBack) {
Icon(
Icons.AutoMirrored.Filled.ArrowBack,
contentDescription = "Back",
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
}
},
colors = TopAppBarDefaults.topAppBarColors(containerColor = Color.Transparent) colors = TopAppBarDefaults.topAppBarColors(containerColor = Color.Transparent)
) )
} }
@@ -0,0 +1,5 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android" android:autoMirrored="true" android:height="24dp" android:tint="#FFFFFF" android:viewportHeight="24" android:viewportWidth="24" android:width="24dp">
<path android:fillColor="@android:color/white" android:pathData="M20,2L4,2c-1.1,0 -1.99,0.9 -1.99,2L2,22l4,-4h14c1.1,0 2,-0.9 2,-2L22,4c0,-1.1 -0.9,-2 -2,-2zM18,14L6,14v-2h12v2zM18,11L6,11L6,9h12v2zM18,8L6,8L6,6h12v2z"/>
</vector>
+10
View File
@@ -0,0 +1,10 @@
# Default ignored files
/shelf/
/workspace.xml
# Ignored default folder with query files
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
+12
View File
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="EMPTY_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
+17
View File
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="chatapp_dev@100.118.108.58" uuid="b14acf5d-6750-469b-8aea-59c8343eb11c">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>
<jdbc-url>jdbc:postgresql://100.118.108.58:5432/chatapp_dev</jdbc-url>
<jdbc-additional-properties>
<property name="com.intellij.clouds.kubernetes.db.host.port" />
<property name="com.intellij.clouds.kubernetes.db.enabled" value="false" />
<property name="com.intellij.clouds.kubernetes.db.container.port" />
</jdbc-additional-properties>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>
+7
View File
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourcePerFileMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
<file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
</component>
</project>
+6
View File
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
</profile>
</component>
+8
View File
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/backend.iml" filepath="$PROJECT_DIR$/.idea/backend.iml" />
</modules>
</component>
</project>
+7
View File
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" dialect="PostgreSQL" />
<file url="PROJECT" dialect="PostgreSQL" />
</component>
</project>
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>
+5 -2
View File
@@ -10,7 +10,7 @@ dotenv = "0.15.0"
futures-util = "0.3.31" futures-util = "0.3.31"
image = "0.25.8" image = "0.25.8"
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] } jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
rand = "0.9.2" rand = "0.8"
redis = { version = "0.25.4", features = ["tokio-comp"] } redis = { version = "0.25.4", features = ["tokio-comp"] }
reqwest = { version = "0.12.23", features = ["json"] } reqwest = { version = "0.12.23", features = ["json"] }
rocket = { version = "0.5.1", features = ["json", "secrets"] } rocket = { version = "0.5.1", features = ["json", "secrets"] }
@@ -20,8 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] }
serde = { version = "1.0.228", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145" serde_json = "1.0.145"
sha2 = "0.10.9" sha2 = "0.10.9"
sqlx = { version = "0.7.4", features = ["macros", "time"] } sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] }
tokio = { version = "1.47.1", features = ["full"] } tokio = { version = "1.47.1", features = ["full"] }
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] } totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
tracing = "0.1.44" tracing = "0.1.44"
uuid = { version = "1.18.1", features = ["v4"] } uuid = { version = "1.18.1", features = ["v4"] }
thiserror = "1.0.69"
utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] }
clap = { version = "4.5", features = ["derive"] }
@@ -1,49 +0,0 @@
-- Add migration script here
CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
password VARCHAR(50) NOT NULL,
display_name VARCHAR(50),
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channels (
id SERIAL PRIMARY KEY,
name VARCHAR(50) NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE messages (
id SERIAL PRIMARY KEY,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
is_edited BOOLEAN DEFAULT FALSE
);
create table attachments (
id SERIAL PRIMARY KEY,
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
path TEXT NOT NULL
);
CREATE INDEX idx_messages_channel_id ON messages (channel_id, id DESC);
CREATE INDEX idx_new_messages ON messages(created_at DESC);
-- Create a function to update the updated_at timestamp
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql';
-- Create trigger for messages table
CREATE TRIGGER update_messages_updated_at
BEFORE UPDATE ON messages
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
@@ -1,20 +0,0 @@
-- Add migration script here
CREATE TABLE sessions (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id),
token TEXT NOT NULL UNIQUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '7 days'
);
CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
RETURNS TRIGGER AS $$
BEGIN
DELETE FROM sessions WHERE expires_at < NOW();
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_cleanup_sessions
AFTER INSERT ON sessions
EXECUTE FUNCTION cleanup_expired_sessions();
@@ -1,5 +0,0 @@
-- Add migration script here
ALTER TABLE users ADD COLUMN email VARCHAR(100);
ALTER TABLE users ADD COLUMN twofa_enabled BOOLEAN DEFAULT FALSE;
ALTER TABLE users ADD COLUMN totp_secret VARCHAR(64);
ALTER TABLE users ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP;
@@ -1,2 +0,0 @@
-- Add migration script here
ALTER TABLE users ALTER COLUMN twofa_enabled SET NOT NULL;
@@ -1,18 +0,0 @@
-- Add migration script here
CREATE TABLE access_codes (
-- identifiers
id SERIAL PRIMARY KEY,
creator_id INTEGER NOT NULL REFERENCES users(id),
-- code data
code VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
-- uses
uses INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 1,
-- time data
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '1 day'
);
@@ -1,10 +0,0 @@
-- Add migration script here
ALTER TABLE access_codes
ALTER COLUMN created_at
TYPE TIMESTAMP WITH TIME ZONE
USING created_at AT TIME ZONE 'UTC';
ALTER TABLE access_codes
ALTER COLUMN expires_at
TYPE TIMESTAMP WITH TIME ZONE
USING expires_at AT TIME ZONE 'UTC';
@@ -1,6 +0,0 @@
-- Add migration script here
TRUNCATE TABLE users CASCADE;
ALTER TABLE users DROP COLUMN password;
ALTER TABLE users ADD COLUMN pass_hash VARCHAR(255) NOT NULL;
ALTER TABLE users ADD CONSTRAINT users_username_unique UNIQUE (username);
@@ -1,13 +0,0 @@
-- Add migration script here
CREATE TYPE status AS ENUM ('pending', 'accepted', 'blocked');
CREATE TABLE relationships (
id SERIAL PRIMARY KEY,
from_user INTEGER NOT NULL REFERENCES users(id),
to_user INTEGER NOT NULL REFERENCES users(id),
status status NOT NULL DEFAULT 'pending',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
);
@@ -0,0 +1,183 @@
-- Add migration script here
-- Add migration script here
CREATE TYPE user_role AS ENUM ('user', 'admin');
CREATE TYPE totp_status AS ENUM ('disabled', 'pending', 'enabled');
CREATE TABLE users (
id BIGSERIAL PRIMARY KEY NOT NULL,
role user_role NOT NULL DEFAULT 'user',
-- profile
nickname VARCHAR(255),
-- basic auth
username VARCHAR(255) UNIQUE NOT NULL,
passhash VARCHAR(255) NOT NULL,
-- email
email VARCHAR(255),
email_verified BOOLEAN DEFAULT FALSE,
-- 2fa
totp_secret VARCHAR(255),
totp_status totp_status NOT NULL DEFAULT 'disabled',
-- update tracking
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE TABLE access_tokens (
id BIGSERIAL PRIMARY KEY NOT NULL,
creator_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
code VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
uses INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '24 hours',
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE refresh_tokens (
id BIGSERIAL PRIMARY KEY NOT NULL,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash VARCHAR(255) NOT NULL,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '7 days'
);
CREATE TABLE spaces (
id BIGSERIAL PRIMARY KEY NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
owner_id BIGINT NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE channels (
id BIGSERIAL PRIMARY KEY NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE space_members (
space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
role user_role DEFAULT 'user',
PRIMARY KEY (space_id, user_id)
);
CREATE TABLE messages (
id BIGSERIAL PRIMARY KEY NOT NULL,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
is_edited BOOLEAN DEFAULT FALSE
);
CREATE TABLE attachments (
id BIGSERIAL PRIMARY KEY NOT NULL,
message_id BIGINT NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
filename VARCHAR(255) NOT NULL,
content_type VARCHAR(100) NOT NULL, -- mime type e.g. image/png, video/mp4
size_bytes BIGINT NOT NULL,
url TEXT NOT NULL, -- path to file on your CDN/storage
width INTEGER, -- null for non-image/video
height INTEGER, -- null for non-image/video
duration_ms INTEGER, -- null for non-audio/video
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TYPE relationship_status AS ENUM ('pending', 'accepted', 'blocked');
CREATE TABLE relationships (
id BIGSERIAL PRIMARY KEY NOT NULL,
from_user BIGINT NOT NULL REFERENCES users(id),
to_user BIGINT NOT NULL REFERENCES users(id),
status relationship_status NOT NULL DEFAULT 'pending',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
);
CREATE INDEX idx_messages_channel_id ON messages(channel_id, created_at DESC);
CREATE INDEX idx_messages_user_id ON messages(user_id);
CREATE INDEX idx_attachments_message ON attachments(message_id);
CREATE INDEX idx_channels_space_id ON channels(space_id);
CREATE INDEX idx_space_members_user ON space_members(user_id);
CREATE INDEX idx_refresh_tokens_hash ON refresh_tokens(token_hash);
CREATE INDEX idx_relationships_from ON relationships(from_user, to_user);
CREATE INDEX idx_relationships_to ON relationships(to_user);
CREATE INDEX idx_access_tokens_code ON access_tokens(code);
CREATE OR REPLACE FUNCTION update_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER users_updated_at
BEFORE UPDATE ON users
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER spaces_updated_at
BEFORE UPDATE ON spaces
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER channels_updated_at
BEFORE UPDATE ON channels
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER messages_updated_at
BEFORE UPDATE ON messages
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER relationships_updated_at
BEFORE UPDATE ON relationships
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER access_tokens_updated_at
BEFORE UPDATE ON access_tokens
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE OR REPLACE FUNCTION add_owner_to_space()
RETURNS TRIGGER AS $$
BEGIN
INSERT INTO space_members (space_id, user_id, role, joined_at)
VALUES (NEW.id, NEW.owner_id, 'admin', NOW());
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER space_owner_becomes_member
AFTER INSERT ON spaces
FOR EACH ROW EXECUTE FUNCTION add_owner_to_space();
@@ -1,21 +1,46 @@
use std::{ use crate::error::ApiResult;
sync::LazyLock, use crate::model::auth::{AccessTokenForm, AuthResponse, LoginCredentials, SignupCredentials};
time::{SystemTime, UNIX_EPOCH}, use crate::svc::auth_svc::AuthService;
}; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome};
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::{Request, State};
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::svc::access_token_svc::AccessTokenService;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode}; #[post("/signup", data = "<cred>")]
use rand::Rng; pub async fn signup(
use rocket::{ cred: Json<SignupCredentials>,
Request, svc: &State<AuthService>
http::Status, ) -> ApiResult<Json<AuthResponse>> {
request::{self, FromRequest, Outcome}, let response = svc
}; .signup(
use rocket_db_pools::Connection; &cred.email, &cred.username, &cred.password, &cred.access_token,
use serde::{Deserialize, Serialize}; ).await?;
use sha2::{Digest, Sha256, digest::block_buffer::Lazy}; Ok(Json(response))
use sqlx::postgres::PgQueryResult; }
use crate::db::Postgres; #[post("/login", data = "<cred>")]
pub async fn login(
cred: Json<LoginCredentials>,
svc: &State<AuthService>
) -> ApiResult<Json<AuthResponse>> {
Ok(Json(svc.login(&cred.username, &cred.password).await?))
}
#[post("/invite", data = "<form>")]
pub async fn generate_invite(
session: Session,
form: Json<AccessTokenForm>,
svc: &State<AccessTokenService>
) -> ApiResult<String> {
svc.create(
session.uid, &form.name, form.max_uses,
form.start_date, form.expiry_date).await
}
static JWT_SECRET: LazyLock<String> = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap()); static JWT_SECRET: LazyLock<String> = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap());
@@ -27,7 +52,7 @@ pub enum TokenScope {
} }
pub struct Session { pub struct Session {
pub user_id: usize, pub uid: i64,
} }
#[rocket::async_trait] #[rocket::async_trait]
@@ -37,7 +62,7 @@ impl<'r> FromRequest<'r> for Session {
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> { async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match Claims::from_request(req).await { match Claims::from_request(req).await {
Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session { Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session {
user_id: user.sub as usize, uid: user.sub as i64,
}), }),
Outcome::Success(_) => { Outcome::Success(_) => {
eprintln!("warning: user with scope other than Full attempted to access session"); eprintln!("warning: user with scope other than Full attempted to access session");
@@ -106,4 +131,4 @@ impl<'r> FromRequest<'r> for Claims {
} }
} }
} }
} }
@@ -26,7 +26,7 @@ pub async fn profile_pic(user_id: usize) -> Option<NamedFile> {
Some(image) Some(image)
} else { } else {
Some( Some(
NamedFile::open("./cdn/profiles/full/default.svg") NamedFile::open("../../cdn/profiles/full/default.svg")
.await .await
.ok()?, .ok()?,
) )
+70
View File
@@ -0,0 +1,70 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::chat_svc::ChatService;
use chrono::{DateTime, Utc};
use rocket::response::stream::Event;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::{Shutdown, State, ___internal_EventStream as EventStream};
use sqlx::FromRow;
use tokio::select;
use tokio::sync::broadcast;
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message(
msg: Json<ChatMsg>,
chat: &State<ChatService>,
session: Session,
channel_id: i64,
) -> ApiResult<()> {
chat.send(channel_id, session.uid, &msg.text, Utc::now()).await
}
#[get("/events/<channel_id>")]
pub async fn event_stream(
chat: &State<ChatService>,
s: Session,
mut shutdown: Shutdown,
channel_id: i64,
) -> ApiResult<EventStream![]> {
let messages = chat.get_messages(channel_id, 100)
.await?; // if get message returned err, inform user.
let mut rx = chat.subscribe(channel_id).await;
let id = s.uid;
Ok(EventStream! {
for msg in messages {
yield Event::json(&msg);
}
loop {
select!{
_ = &mut shutdown => break, // exit early on shutdown
msg = rx.recv() => match msg {
Ok(msg) => {
tracing::info!("yielding message!");
yield Event::json(&msg)
},
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",);
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => {
tracing::info!("Broadcaster hung up on channel {channel_id}!");
break
},
},
}
}
})
}
+7
View File
@@ -0,0 +1,7 @@
pub mod auth;
pub mod chat;
pub mod totp;
pub mod settings;
pub mod cdn;
pub mod profile;
pub mod space;
+13
View File
@@ -0,0 +1,13 @@
use rocket::State;
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::user_svc::UserService;
#[get("/users/<id>")]
pub async fn display_name(
id: i64,
_ag: Session,
svc: &State<UserService>,
) -> ApiResult<String> {
svc.get_username(id).await
}
+68
View File
@@ -0,0 +1,68 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::settings_svc::SettingsService;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::State;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordForm {
pub old_password: String,
pub new_password: String,
}
#[post("/settings/password", data = "<form>")]
pub async fn change_password(
session: Session,
form: Json<PasswordForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_password(
session.uid, &form.old_password, &form.new_password
).await
}
#[derive(Deserialize, Debug, Clone)]
pub struct DisplayNameForm {
pub display_name: Option<String>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct PasswordAnd2faForm {
pub password: String,
pub totp_code: Option<String>,
}
#[delete("/settings", data = "<data>")]
pub async fn delete_account(
session: Session,
data: Json<PasswordAnd2faForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.delete_account(
session.uid, &data.password, &data.totp_code
).await
}
#[patch("/settings/display_name", data = "<new>")]
pub async fn change_display_name(
session: Session,
new: Json<DisplayNameForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_display_name(session.uid, new.display_name.clone()).await
}
#[derive(Deserialize)]
pub struct UsernameForm {
pub username: String,
}
#[patch("/settings/username", data = "<new>")]
pub async fn change_username(
session: Session,
new: Json<UsernameForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_username(session.uid, &new.username).await
}
+36
View File
@@ -0,0 +1,36 @@
use crate::error::ApiResult;
use crate::model::space::{Space, SpaceDto};
use crate::model::space::Channel;
use crate::repo::{SpaceRepo, ChannelRepo};
use rocket::serde::json::Json;
use rocket::State;
use std::sync::Arc;
use crate::api::auth::Session;
use crate::svc::chat_svc::ChatService;
#[get("/spaces")]
pub async fn list_spaces(
space_repo: &State<Arc<dyn SpaceRepo>>
) -> ApiResult<Json<Vec<Space>>> {
let spaces = space_repo.get_all().await?;
Ok(Json(spaces))
}
#[get("/spaces/<space_id>/channels")]
pub async fn list_channels(
space_id: i64,
channel_repo: &State<Arc<dyn ChannelRepo>>
) -> ApiResult<Json<Vec<Channel>>> {
let channels = channel_repo.get_by_space_id(space_id).await?;
Ok(Json(channels))
}
#[get("/accessible_channels")]
pub async fn get_accessible_channels(
session: Session,
svc: &State<ChatService>
) -> ApiResult<Json<Vec<SpaceDto>>> {
let space = svc.get_accessible_channels(session.uid).await?;
println!("{:?}", space);
Ok(Json(space))
}
+120
View File
@@ -0,0 +1,120 @@
use crate::api::auth::{Claims, Session, TokenScope};
use crate::error::{ApiResult, AppError};
use crate::model::auth::AuthResponse;
use crate::svc::auth_svc::AuthService;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::State;
use totp_rs::{Algorithm, TOTP};
#[derive(Debug, Deserialize)]
pub struct TOTPSixDigitCode {
code: String,
}
#[derive(Debug, sqlx::Type, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[sqlx(type_name = "totp_status", rename_all = "lowercase")]
pub enum TotpStatus {
Enabled,
Pending,
Disabled,
}
#[derive(Serialize)]
pub struct QrResponse {
qr_code: String,
}
#[derive(Deserialize)]
pub struct TotpVerifyRequest {
pub code: String,
}
#[derive(Deserialize)]
pub struct PasswordConfirmation {
password: String,
}
#[derive(Deserialize)]
pub struct PasswordAnd2fa {
pub password: String,
pub totp_code: String,
}
pub fn totp_gen(user_id: i64, secret: &[u8]) -> ApiResult<TOTP> {
TOTP::new(
Algorithm::SHA1,
6,
1,
30,
secret.to_owned(),
Some("chat.zxq5.dev".to_string()),
format!("{}", user_id),
)
.map_err(|_| AppError::internal("failed to generate totp"))
}
#[post("/totp", data = "<form>")]
pub async fn confirm_totp(
user: Session,
form: Json<TOTPSixDigitCode>,
svc: &State<AuthService>,
) -> ApiResult<()> {
svc.confirm_totp(user.uid, &form.code).await
}
#[post("/totp.jpg", data = "<form>")]
pub async fn get_totp(
user: Session,
form: Json<PasswordConfirmation>,
svc: &State<AuthService>,
) -> ApiResult<Json<QrResponse>> {
let secret = svc.get_or_create_totp_secret(user.uid, &form.password).await?;
let qr_b64 = totp_gen(user.uid, secret.as_bytes())
.map_err(|_| AppError::internal("invalid totp secret"))?
.get_qr_base64()
.map_err(|_| AppError::internal("failed to generate qr code"))?;
Ok(Json(QrResponse {
qr_code: format!("data:image/png;base64,{}", qr_b64),
}))
}
#[get("/totp/status")]
pub async fn get_totp_status(
user: Session,
svc: &State<AuthService>,
) -> ApiResult<Json<TotpStatus>> {
Ok(Json(
svc.get_totp_status(user.uid).await?
.then_some(TotpStatus::Enabled)
.unwrap_or(TotpStatus::Disabled),
))
}
#[delete("/totp", data = "<form>")]
pub async fn disable_totp(
user: Session,
form: Json<PasswordAnd2fa>,
svc: &State<AuthService>,
) -> ApiResult<Json<AuthResponse>> {
let response = svc.disable_totp(user.uid, &form.password, &form.totp_code).await?;
Ok(Json(response))
}
#[post("/totp/verify", data = "<body>")]
pub async fn verify_totp(
claims: Claims,
body: Json<TotpVerifyRequest>,
svc: &State<AuthService>,
) -> ApiResult<Json<AuthResponse>> {
// reject if they somehow got here with a full token
if claims.scope != TokenScope::TotpPending {
return Err(AppError::Forbidden);
}
let response = svc.login_totp(claims.sub as i64, &body.code).await?;
Ok(Json(response))
}
-211
View File
@@ -1,211 +0,0 @@
use argon2::{
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
password_hash::{SaltString, rand_core::OsRng},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use rocket::{
http::{CookieJar, Status},
response::{Redirect, status::BadRequest},
serde::json::Json,
time::OffsetDateTime,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{
auth::session::{Claims, Session, TokenScope},
db::Postgres,
user::User,
};
#[derive(Serialize, Deserialize)]
pub struct SignupCredentials {
pub email: String,
pub username: String,
pub password: String,
pub access_token: String,
}
#[derive(Serialize, Deserialize)]
pub struct LoginCredentials {
pub username: String,
pub password: String,
}
#[derive(Serialize, Deserialize)]
pub struct AuthResponse {
pub token: String,
pub totp_required: bool,
}
#[get("/signup")]
pub async fn signup_page() -> Template {
Template::render("signup", context!())
}
#[post("/signup", data = "<cred>")]
pub async fn signup(
cred: Json<SignupCredentials>,
jar: &CookieJar<'_>,
mut db: Connection<Postgres>,
) -> Result<Json<AuthResponse>, Status> {
let token_id = AccessToken::validate(&cred.access_token, &mut db)
.await
.map_err(|_| Status::Unauthorized)?;
let salt = SaltString::generate(&mut OsRng);
let hashed = Argon2::default()
.hash_password(cred.password.as_bytes(), &salt)
.map_err(|_| Status::InternalServerError)?
.to_string();
let result = sqlx::query!(
"INSERT INTO users (email, username, pass_hash) VALUES ($1, $2, $3) RETURNING id",
cred.email,
cred.username,
hashed,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::InternalServerError)?;
let jwt = Claims::new(result.id as usize, TokenScope::Full).encode();
token_id
.use_token(&mut db)
.await
.expect("unable to use access code");
Ok(Json(AuthResponse {
token: jwt,
totp_required: false,
}))
}
#[get("/login")]
pub async fn login_page() -> Template {
Template::render("login", context!())
}
#[post("/login", data = "<cred>")]
pub async fn login(
mut db: Connection<Postgres>,
cred: Json<LoginCredentials>,
) -> Result<Json<AuthResponse>, Status> {
println!("e");
let row = sqlx::query!(
"SELECT id, pass_hash, twofa_enabled FROM users WHERE username = $1",
cred.username,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::Unauthorized)?;
println!("ok");
// verify password as before
let parsed_hash = PasswordHash::new(&row.pass_hash).map_err(|_| Status::InternalServerError)?;
Argon2::default()
.verify_password(cred.password.as_bytes(), &parsed_hash)
.map_err(|_| Status::Unauthorized)?;
println!("ok2");
let user_id = row.id as usize;
// issue either a partial or full token depending on 2FA status
let (session, totp_required) = if row.twofa_enabled {
(Claims::new(user_id, TokenScope::TotpPending), true)
} else {
(Claims::new(user_id, TokenScope::Full), false)
};
Ok(Json(AuthResponse {
token: session.encode(),
totp_required,
}))
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenForm {
pub name: String,
pub max_uses: usize,
pub expiry_date: usize,
pub start_date: usize,
}
#[get("/invite")]
pub async fn invite_page(_s: Session) -> Template {
Template::render("invite", context! {})
}
#[post("/invite", data = "<form>")]
pub async fn generate_invite(
session: Session,
mut db: Connection<Postgres>,
form: Json<AccessTokenForm>,
) -> Result<String, Status> {
if form.start_date > form.expiry_date {
return Err(Status::BadRequest);
}
let code = Uuid::new_v4().to_string();
sqlx::query!(
"INSERT INTO access_codes (name, code, creator_id, max_uses, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
form.name,
code,
session.user_id as i32,
form.max_uses as i32,
OffsetDateTime::from_unix_timestamp_nanos(form.start_date as i128 * 1_000_000).unwrap(),
OffsetDateTime::from_unix_timestamp_nanos(form.expiry_date as i128 * 1_000_000).unwrap()
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::InternalServerError)?;
Ok(code)
}
pub struct AccessToken {
id: i32,
_code: String,
}
impl AccessToken {
pub async fn validate(
token: &str,
db: &mut Connection<Postgres>,
) -> Result<AccessToken, String> {
match sqlx::query!(
"SELECT id FROM access_codes
WHERE code = $1
AND created_at < NOW()
AND expires_at > NOW()
AND uses < max_uses",
token
)
.fetch_one(&mut ***db)
.await
{
Ok(row) => Ok(AccessToken {
id: row.id,
_code: token.to_string(),
}),
Err(_) => Err(String::from("Invalid or Expired token!")),
}
}
pub async fn use_token(&self, db: &mut Connection<Postgres>) -> Result<(), String> {
sqlx::query!(
"UPDATE access_codes SET uses = uses + 1 WHERE id = $1",
self.id
)
.execute(&mut ***db)
.await
.map_err(|_| String::from("Invalid or Expired token!"))?;
Ok(())
}
}
-12
View File
@@ -1,12 +0,0 @@
pub mod account;
pub mod profile;
pub mod session;
pub mod two_factor;
pub use session::Session;
pub use account::{generate_invite, invite_page, login, login_page, signup, signup_page};
pub use profile::{change_display_name, change_password, change_username, delete_account};
pub use two_factor::{
confirm_totp, disable_totp, get_totp, get_totp_status, mfa_page, verify_totp,
};
-143
View File
@@ -1,143 +0,0 @@
use argon2::{
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
password_hash::{SaltString, rand_core::OsRng},
};
use rocket::{http::Status, serde::json::Json};
use rocket_db_pools::Connection;
use serde::{Deserialize, Serialize};
use crate::{auth::Session, db::Postgres, user::User};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordForm {
old_password: String,
new_password: String,
}
#[post("/settings/password", data = "<form>")]
pub async fn change_password(
session: Session,
mut db: Connection<Postgres>,
form: Json<PasswordForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.verify_password(&form.old_password)?;
// old password is correct, so new one can be set.
let salt = SaltString::generate(&mut OsRng);
let hashed = Argon2::default()
.hash_password(form.new_password.as_bytes(), &salt)
.inspect_err(|e| tracing::error!("failed to hash password! {e}"))
.map_err(|_| Status::InternalServerError)?
.to_string();
user.set_pass_hash(hashed, &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[derive(Deserialize, Debug, Clone)]
pub struct DisplayNameForm {
pub display_name: Option<String>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct PasswordAnd2fa {
pub password: String,
pub totp_code: Option<String>,
}
#[delete("/settings", data = "<data>")]
pub async fn delete_account(
session: Session,
mut db: Connection<Postgres>,
data: Json<PasswordAnd2fa>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.verify_password(&data.password)?;
if user.twofa_enabled {
user.verify_2fa(data.totp_code.as_deref().unwrap_or(""))?;
}
user.delete(&mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[patch("/settings/display_name", data = "<new>")]
pub async fn change_display_name(
session: Session,
mut db: Connection<Postgres>,
new: Json<DisplayNameForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.set_display_name(new.display_name.clone(), &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[derive(Deserialize)]
pub struct UsernameForm {
username: String,
}
#[patch("/settings/username", data = "<new>")]
pub async fn change_username(
session: Session,
mut db: Connection<Postgres>,
new: Json<UsernameForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.set_username(new.username.clone(), &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
-301
View File
@@ -1,301 +0,0 @@
use futures_util::TryFutureExt;
use rocket::{
Request,
http::Status,
outcome::{Outcome, try_outcome},
request::{self, FromRequest},
response::status::{self},
serde::json::Json,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use totp_rs::{Algorithm, Secret, TOTP};
use crate::{
auth::{
account::AuthResponse,
profile::PasswordAnd2fa,
session::{Claims, Session, TokenScope},
},
db::Postgres,
user::User,
};
// Utility methods
pub fn totp_gen(user_id: usize, secret: &[u8]) -> Result<TOTP, String> {
TOTP::new(
Algorithm::SHA1,
6,
1,
30,
secret.to_owned(),
Some("chat.zxq5.dev".to_string()),
format!("{}", user_id),
)
.map_err(|_| String::from("Invalid Secret"))
}
// pages
#[get("/totp")]
pub async fn mfa_page(_session: Session) -> Template {
Template::render("2fa", context!())
}
#[post("/totp", data = "<form>")]
pub async fn confirm_totp(
mfa: TOTPSecret,
form: Json<TOTPSixDigitCode>,
mut db: Connection<Postgres>,
) -> Result<(), status::Custom<&'static str>> {
if form.code.len() != 6 || form.code.parse::<u32>().is_err() {
return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code"));
}
println!("valid");
let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes())
.map_err(|_| status::Custom(Status::InternalServerError, "TOTP Error"))?;
if !totp.check_current(&form.code).unwrap_or(false) {
return Err(status::Custom(Status::BadRequest, "Incorrect code"));
}
println!("correct");
if sqlx::query!(
"UPDATE users SET twofa_enabled = true WHERE id = $1",
mfa.user_id as i32
)
.execute(&mut **db)
.await
.is_err()
{
return Err(status::Custom(
Status::InternalServerError,
"unable to enable 2fa",
));
};
println!("enabled");
Ok(())
}
#[derive(Deserialize)]
pub struct PasswordConfirmation {
password: String,
}
#[post("/totp.jpg", data = "<form>")]
pub async fn get_totp(
mfa: TOTPSecret,
form: Json<PasswordConfirmation>,
) -> Option<Json<QrResponse>> {
let qr_b64 = totp_gen(mfa.user_id, mfa.secret.as_bytes())
.expect("Invalid TOTP")
.get_qr_base64()
.unwrap();
Some(Json(QrResponse {
qr_code: format!("data:image/png;base64,{}", qr_b64),
}))
}
#[derive(Debug, Deserialize)]
pub struct TOTPSixDigitCode {
code: String,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TotpStatus {
Enabled,
Disabled,
}
pub struct TOTPSecret {
user_id: usize,
secret: String,
}
#[derive(Serialize)]
pub struct QrResponse {
qr_code: String,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for TOTPSecret {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
let auth_header = request.headers().get_one("Authorization");
println!(
"TOTPSecret guard - Auth header present: {}",
auth_header.is_some()
);
let user = try_outcome!(request.guard::<Claims>().await);
println!(
"TOTPSecret guard - Claims ok, user: {}, scope: {:?}",
user.sub, user.scope
);
// only allow full tokens for TOTP setup
if user.scope != TokenScope::Full {
println!("TOTPSecret guard - rejected, scope is {:?}", user.scope);
return Outcome::Error((Status::Forbidden, ()));
}
let user = try_outcome!(request.guard::<Session>().await);
let mut pool = match request.guard::<Connection<Postgres>>().await {
Outcome::Success(pool) => pool,
_ => return Outcome::Error((Status::Unauthorized, ())),
};
let row = sqlx::query!(
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
user.user_id as i32
)
.fetch_one(&mut **pool)
.await;
let (enabled, mut secret) = match row {
Ok(r) => (r.twofa_enabled, r.totp_secret),
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
};
if secret.is_none() {
let new_secret = Secret::generate_secret().to_encoded().to_string();
sqlx::query!(
"UPDATE users SET totp_secret = $1 WHERE id = $2",
new_secret,
user.user_id as i32
)
.execute(&mut **pool)
.await
.ok();
secret = Some(new_secret);
}
Outcome::Success(TOTPSecret {
user_id: user.user_id,
secret: secret.unwrap(),
})
}
}
impl TOTPSecret {
pub async fn enable(&self, db: &mut Connection<Postgres>) -> Result<(), ()> {
match sqlx::query!(
"UPDATE users SET twofa_enabled = true WHERE id = $1",
self.user_id as i32,
)
.execute(&mut ***db)
.await
{
Ok(_) => Ok(()),
Err(_) => Err(()),
}
}
}
#[derive(Deserialize)]
pub struct TotpVerifyRequest {
pub code: String,
}
#[get("/totp/status")]
pub async fn get_totp_status(
user: Session,
mut db: Connection<Postgres>,
) -> Result<Json<TotpStatus>, Status> {
Ok(Json(
if sqlx::query!(
"SELECT twofa_enabled FROM users WHERE id = $1",
user.user_id as i32,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::NotFound)?
.twofa_enabled
{
TotpStatus::Enabled
} else {
TotpStatus::Disabled
},
))
}
#[delete("/totp", data = "<form>")]
pub async fn disable_totp(
user: Session,
mut db: Connection<Postgres>,
form: Json<PasswordAnd2fa>,
) -> Result<Json<AuthResponse>, Status> {
let totp_code = form.totp_code.clone().ok_or(Status::BadRequest)?;
let mut user = User::get_by_id(user.user_id, &mut db)
.await
.ok_or(Status::NotFound)?;
user.verify_password(&form.password)?;
user.verify_2fa(&totp_code)?;
user.set_twofa_enabled(false, &mut db)
.await
.map_err(|_| Status::InternalServerError)?;
Ok(Json(AuthResponse {
token: Claims::new(user.id as usize, TokenScope::Full).encode(),
totp_required: false,
}))
}
#[post("/totp/verify", data = "<body>")]
pub async fn verify_totp(
claims: Claims, // request guard checks token validity
mut db: Connection<Postgres>,
body: Json<TotpVerifyRequest>,
) -> Result<Json<AuthResponse>, Status> {
println!("reached 1");
// reject if they somehow got here with a full token
if claims.scope != TokenScope::TotpPending {
return Err(Status::Forbidden);
}
println!("reached 2");
let row = sqlx::query!(
"SELECT totp_secret FROM users WHERE id = $1 AND twofa_enabled = TRUE",
claims.sub
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::Unauthorized)?;
println!("reached 3");
let totp = totp_gen(
claims.sub as usize,
row.totp_secret
.expect("user with 2fa enabled has no totp secret")
.as_bytes(),
)
.map_err(|_| Status::InternalServerError)?;
if !totp
.check_current(&body.code)
.map_err(|_| Status::InternalServerError)?
{
return Err(Status::Unauthorized);
}
println!("reached 5");
let claims = Claims::new(claims.sub as usize, TokenScope::Full);
Ok(Json(AuthResponse {
token: claims.encode(),
totp_required: false,
}))
}
+96
View File
@@ -0,0 +1,96 @@
use clap::{Parser, Subcommand};
use sqlx::postgres::PgPoolOptions;
use std::time::Duration;
use std::sync::Arc;
use crate::repo::user_repo::UserRepository;
use crate::repo::space_repo::SpaceRepository;
use crate::repo::channel_repo::ChannelRepository;
use crate::repo::{UserRepo, SpaceRepo, ChannelRepo};
use argon2::{
password_hash::{PasswordHasher, SaltString},
Argon2,
};
use rand::rngs::OsRng;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
pub struct Cli {
#[command(subcommand)]
pub command: Option<Commands>,
}
#[derive(Subcommand)]
pub enum Commands {
/// First-time setup for the server
Setup {
/// Admin username
#[arg(short, long)]
username: String,
/// Admin password
#[arg(short, long)]
password: String,
/// Default space name
#[arg(short, long, default_value = "Default Space")]
space: String,
/// Default channel name
#[arg(short, long, default_value = "general")]
channel: String,
},
}
pub async fn handle_cli() -> bool {
let cli = Cli::parse();
match cli.command {
Some(Commands::Setup { username, password, space, channel }) => {
if let Err(e) = run_setup(username, password, space, channel).await {
eprintln!("Setup failed: {}", e);
std::process::exit(1);
}
println!("Setup completed successfully!");
true
}
None => false,
}
}
async fn run_setup(username: String, password: String, space_name: String, channel_name: String) -> Result<(), Box<dyn std::error::Error>> {
dotenv::dotenv().ok();
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let pool = PgPoolOptions::new()
.max_connections(1)
.acquire_timeout(Duration::from_secs(5))
.connect(&db_url)
.await?;
let user_repo = UserRepository::new(pool.clone());
let space_repo = SpaceRepository::new(pool.clone());
let channel_repo = ChannelRepository::new(pool.clone());
// 1. Create admin user
println!("Creating admin user: {}...", username);
let argon2 = Argon2::default();
let salt = SaltString::generate(&mut OsRng);
let passhash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| e.to_string())?
.to_string();
let user_id = user_repo.new_user("admin@localhost", &username, &passhash).await?;
user_repo.set_role(user_id, "admin").await?;
// 2. Create default space
println!("Creating default space: {}...", space_name);
let space_id = space_repo.create(&space_name, Some("Default space created during setup"), user_id).await?;
// 3. Create default channel
println!("Creating default channel: {}...", channel_name);
channel_repo.create(&channel_name, Some("Default channel"), space_id).await?;
Ok(())
}
-9
View File
@@ -1,9 +0,0 @@
use rocket_db_pools::{Database, deadpool_redis};
#[derive(Database)]
#[database("postgres_db")]
pub struct Postgres(sqlx::PgPool);
#[derive(Database)]
#[database("redis_cache")]
pub struct Redis(deadpool_redis::Pool);
+127
View File
@@ -0,0 +1,127 @@
// error.rs
use rocket::{http::Status, response::{self, Responder}, Request, Response};
use thiserror::Error;
use rocket_dyn_templates::Template;
use rocket::serde::Serialize;
#[derive(Error, Debug)]
pub enum AppError {
#[error("Not found")]
NotFound,
#[error("Unauthorized")]
Unauthorised(String),
#[error("Forbidden")]
Forbidden,
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Internal error: {0}")]
Internal(String),
}
impl AppError {
pub fn internal(msg: impl Into<String>) -> Self {
Self::Internal(msg.into())
}
pub fn bad_request(msg: impl Into<String>) -> Self {
Self::BadRequest(msg.into())
}
pub fn unauthorised(msg: impl Into<String>) -> Self {
Self::Unauthorised(msg.into())
}
}
impl<'r> Responder<'r, 'static> for AppError {
fn respond_to(self, _req: &'r Request<'_>) -> response::Result<'static> {
let status = match &self {
AppError::NotFound => Status::NotFound,
AppError::Unauthorised(_) => Status::Unauthorized,
AppError::Forbidden => Status::Forbidden,
AppError::BadRequest(_) => Status::BadRequest,
AppError::Database(_) => Status::InternalServerError,
AppError::Internal(_) => Status::InternalServerError,
};
// log internal errors
if status == Status::InternalServerError {
tracing::error!("Internal Server Error: {}", self);
}
Response::build()
.status(status)
.header(rocket::http::ContentType::Plain)
.sized_body(
self.to_string().len(),
std::io::Cursor::new(self.to_string())
)
.ok()
}
}
pub type ApiResult<T> = Result<T, AppError>;
#[derive(Serialize)]
struct ErrorContext {
error_code: u16,
error_message: &'static str,
additional_info: &'static str,
redirect: Option<RedirectContext>,
}
#[derive(Serialize)]
struct RedirectContext {
url: &'static str,
message: &'static str,
}
#[catch(404)]
pub async fn handle_404() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 404,
error_message: "Not Found",
additional_info: "There's nothing here.",
redirect: Some(RedirectContext {
url: "/",
message: "Home",
}),
},
)
}
#[catch(401)]
pub async fn handle_401() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 401,
error_message: "Unauthorised",
additional_info: "You are not authorised to access this resource.",
redirect: Some(RedirectContext {
url: "/login",
message: "Login",
}),
},
)
}
#[catch(default)]
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
Template::render(
"error",
ErrorContext {
error_code: status.code,
error_message: "Unknown Error",
additional_info: "I don't know what to do with this error.",
redirect: None,
},
)
}
-62
View File
@@ -1,62 +0,0 @@
use rocket::{Request, http::Status};
use rocket_dyn_templates::Template;
use serde::Serialize;
#[derive(Serialize)]
struct ErrorContext {
error_code: u16,
error_message: &'static str,
additional_info: &'static str,
redirect: Option<RedirectContext>,
}
#[derive(Serialize)]
struct RedirectContext {
url: &'static str,
message: &'static str,
}
#[catch(404)]
pub async fn handle_404() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 404,
error_message: "Not Found",
additional_info: "There's nothing here.",
redirect: Some(RedirectContext {
url: "/",
message: "Home",
}),
},
)
}
#[catch(401)]
pub async fn handle_401() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 401,
error_message: "Unauthorised",
additional_info: "You are not authorised to access this resource.",
redirect: Some(RedirectContext {
url: "/login",
message: "Login",
}),
},
)
}
#[catch(default)]
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
Template::render(
"error",
ErrorContext {
error_code: status.code,
error_message: "Unknown Error",
additional_info: "I don't know what to do with this error.",
redirect: None,
},
)
}
+149
View File
@@ -0,0 +1,149 @@
#![deny(clippy::unwrap_used)]
#![warn(clippy::all, clippy::nursery, clippy::cargo, clippy::pedantic)]
#[macro_use]
extern crate rocket;
pub mod messenger;
pub mod api;
pub mod repo;
pub mod error;
pub mod svc;
pub mod model;
pub mod cli;
use crate::repo::{access_token_repo::AccessTokenRepo, Repo};
use crate::repo::message_repo::MessageRepository;
use crate::repo::user_repo::UserRepository;
use crate::repo::space_repo::SpaceRepository;
use crate::repo::channel_repo::ChannelRepository;
use crate::svc::auth_svc::AuthService;
use crate::svc::chat_svc::ChatService;
use crate::svc::settings_svc::SettingsService;
use crate::svc::user_svc::UserService;
use rocket::fs::{FileServer, NamedFile};
use rocket::http::Method;
use rocket_cors::{AllowedOrigins, CorsOptions};
use rocket_dyn_templates::Template;
use sqlx::postgres::PgPoolOptions;
use std::env;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use api::cdn;
use crate::svc::access_token_svc::AccessTokenService;
use crate::svc::llm_service::LlmService;
pub fn rocket() -> rocket::Rocket<rocket::Build> {
if std::env::var("RELEASE_MODE").unwrap_or_default() != "1" {
dotenv::dotenv().ok();
}
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let pool = PgPoolOptions::new()
.max_connections(25)
.min_connections(5)
.acquire_timeout(Duration::from_secs(5))
.connect_lazy(&db_url)
.expect("Failed to create database pool");
let user_repo = Arc::new(UserRepository::new(pool.clone()));
let message_repo = MessageRepository::new(pool.clone());
let token_repo = Arc::new(AccessTokenRepo::new(pool.clone()));
let space_repo: Arc<dyn repo::SpaceRepo> = Arc::new(SpaceRepository::new(pool.clone()));
let channel_repo: Arc<dyn repo::ChannelRepo> = Arc::new(ChannelRepository::new(pool.clone()));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, llm_service.clone(), message_repo.clone(), user_repo.clone(), channel_repo.clone(), space_repo.clone());
rocket_builder(user_repo, token_repo, space_repo, channel_repo, chat_service)
}
pub fn rocket_builder(
user_repo: Arc<dyn repo::UserRepo>,
token_repo: Arc<dyn repo::AccessTokenRepoTrait>,
space_repo: Arc<dyn repo::SpaceRepo>,
channel_repo: Arc<dyn repo::ChannelRepo>,
chat_service: ChatService
) -> rocket::Rocket<rocket::Build> {
let cors = CorsOptions::default()
.allowed_origins(AllowedOrigins::all())
.allowed_methods(
vec![Method::Get, Method::Post, Method::Patch]
.into_iter()
.map(From::from)
.collect(),
)
.allow_credentials(true);
let access_token_svc = AccessTokenService::new(token_repo.clone());
let auth_service = AuthService::new(user_repo.clone(), access_token_svc.clone());
let settings_service = SettingsService::new(auth_service.clone(), user_repo.clone());
let user_service = UserService::new(user_repo.clone());
rocket::build()
.manage(chat_service)
.manage(auth_service)
.manage(settings_service)
.manage(user_service)
.manage(space_repo)
.manage(channel_repo)
.attach(cors.to_cors().unwrap())
.attach(Template::fairing())
.mount("/static", FileServer::from("static"))
.mount("/cdn", cdn::routes())
.mount(
"/",
routes![
favicon,
],
)
.mount(
"/api",
routes![
cdn::upload_profile_pic,
api::profile::display_name,
// basic auth
api::auth::login,
api::auth::signup,
// 2fa
api::totp::confirm_totp,
api::totp::disable_totp,
api::totp::get_totp,
api::totp::get_totp_status,
api::totp::verify_totp,
// chat
api::chat::event_stream,
api::chat::post_message,
// user settings
api::settings::change_display_name,
api::settings::change_password,
api::settings::change_username,
api::settings::delete_account,
// spaces
api::space::list_spaces,
api::space::list_channels,
api::space::get_accessible_channels
],
)
.register(
"/",
catchers![
error::handle_401,
error::handle_404,
error::handle_default,
],
)
}
#[get("/favicon.ico")]
pub async fn favicon() -> NamedFile {
NamedFile::open("static/favicon.ico").await.unwrap()
}
-69
View File
@@ -1,69 +0,0 @@
// src/llm.rs
use serde::{Deserialize, Serialize};
use crate::messenger::ChatMsg;
#[derive(Serialize)]
struct LlmRequest {
model: String,
messages: Vec<Message>,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String, // "user" or "assistant"
content: String,
}
pub struct LlmWorker {
uri: String,
}
impl LlmWorker {
pub fn new(uri: String) -> Self {
Self { uri }
}
pub async fn query(&self, message: &ChatMsg) -> Result<ChatMsg, String> {
let client = reqwest::Client::new();
// Build the request body
let payload = LlmRequest {
model: "gpt-oss-20b".into(), // whatever model you run locally
messages: vec![Message {
role: "user".into(),
content: message.text.clone(),
}],
};
// POST to lmstudio (default 127.0.0.1:1234)
let resp = client
.post(self.uri.clone())
.json(&payload)
.send()
.await
.map_err(|_| String::from("Failed to make request to LLM server"))?;
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)]
struct LlmResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize)]
struct Choice {
message: Message,
}
let llm_resp: LlmResponse = resp
.json()
.await
.map_err(|_| String::from("Failed to make request to LLM server"))?;
Ok(ChatMsg {
display_name: Some(String::from("lmstudio")),
user_id: 0,
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now().timestamp_millis() as usize,
})
}
}
+9 -96
View File
@@ -1,98 +1,11 @@
// src/main.rs use backend::rocket;
#[macro_use] use backend::cli::handle_cli;
extern crate rocket;
use rocket::fs::{FileServer, NamedFile}; #[rocket::main]
use rocket::http::Method; async fn main() -> Result<(), rocket::Error> {
use rocket::{Build, Rocket}; if handle_cli().await {
use rocket_cors::{AllowedOrigins, CorsOptions}; return Ok(());
use rocket_db_pools::Database; }
use rocket_dyn_templates::Template; rocket().launch().await?;
use std::env; Ok(())
use std::sync::{Arc, LazyLock};
use crate::db::{Postgres, Redis};
pub mod auth;
pub mod cdn;
pub mod db;
pub mod handlers;
pub mod llm;
pub mod messenger;
pub mod user;
static LMSTUDIO_URL: LazyLock<String> =
LazyLock::new(|| env::var("LMSTUDIO_URL").expect("Ensure LMSTUDIO_URL is set!"));
#[launch]
fn rocket() -> Rocket<Build> {
// make sure the env is loaded
dotenv::dotenv().expect("Failed to load env! aborting launch!");
let chat = Arc::new(crate::messenger::ChatBroadcaster::new(32));
let cors = CorsOptions::default()
.allowed_origins(AllowedOrigins::all())
.allowed_methods(
vec![Method::Get, Method::Post, Method::Patch]
.into_iter()
.map(From::from)
.collect(),
)
.allow_credentials(true);
rocket::build()
.manage(chat)
.attach(cors.to_cors().unwrap())
.attach(Postgres::init())
.attach(Redis::init())
.attach(Template::fairing())
.mount("/static", FileServer::from("static"))
.mount("/cdn", cdn::routes())
.mount(
"/",
routes![
favicon,
messenger::chat_page,
auth::signup_page,
auth::login_page,
auth::mfa_page,
auth::invite_page,
],
)
.mount(
"/api",
routes![
cdn::upload_profile_pic,
messenger::post_message,
messenger::event_stream,
user::users,
user::display_name,
auth::signup,
auth::login,
auth::get_totp,
auth::confirm_totp,
auth::generate_invite,
auth::verify_totp,
auth::disable_totp,
auth::get_totp_status,
auth::change_password,
auth::change_display_name,
auth::change_username,
auth::delete_account,
],
)
.register(
"/",
catchers![
handlers::handle_404,
handlers::handle_401,
handlers::handle_default
],
)
}
#[get("/favicon.ico")]
async fn favicon() -> NamedFile {
NamedFile::open("static/favicon.ico").await.unwrap()
} }
+2 -4
View File
@@ -1,10 +1,8 @@
use redis::AsyncCommands; use redis::AsyncCommands;
use rocket_db_pools::Connection; use rocket_db_pools::Connection;
use crate::{ use crate::api::chat::ChatMsg;
db::{Postgres, Redis}, use crate::db::{Postgres, Redis};
messenger::ChatMsg,
};
// Helper function to cache message in Redis // Helper function to cache message in Redis
pub async fn insert( pub async fn insert(
-220
View File
@@ -1,220 +0,0 @@
use std::sync::Arc;
use rocket::{
Shutdown,
response::stream::{Event, EventStream},
serde::json::Json,
time::OffsetDateTime,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use sqlx::prelude::FromRow;
use tokio::{select, sync::broadcast};
use crate::{
auth::Session,
db::{Postgres, Redis},
llm::LlmWorker,
messenger,
};
/// ---------- shared broadcaster ----------
pub struct ChatBroadcaster {
buffer_size: usize,
senders: std::sync::Mutex<std::collections::HashMap<i32, broadcast::Sender<ChatMsg>>>,
}
impl ChatBroadcaster {
pub fn new(buffer_size: usize) -> Self {
Self {
buffer_size,
senders: std::sync::Mutex::new(std::collections::HashMap::new()),
}
}
/// Publish a message to the specified channel.
pub async fn publish(&self, channel_id: i32, msg: ChatMsg) {
let mut map = self.senders.lock().unwrap();
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
let _ = sender.send(msg);
}
/// Subscribe to the specified channel.
pub fn subscribe(&self, channel_id: i32) -> broadcast::Receiver<ChatMsg> {
let mut map = self.senders.lock().unwrap();
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
sender.subscribe()
}
}
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub display_name: Option<String>,
pub user_id: usize,
pub text: String,
pub timestamp: usize,
}
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message(
mut msg: Json<ChatMsg>,
chat: &rocket::State<Arc<ChatBroadcaster>>,
mut postgres: Connection<Postgres>,
mut cache: Option<Connection<Redis>>,
session: Session,
channel_id: i32,
) -> Result<(), String> {
let chat = chat.inner().clone();
let display_name = sqlx::query!(
"SELECT display_name, username FROM users WHERE id = $1",
session.user_id as i32
)
.fetch_one(&mut **postgres)
.await
.map(|row| row.display_name.unwrap_or(row.username))
.unwrap_or_else(|_| "Unknown".to_string());
msg.user_id = session.user_id;
msg.display_name = Some(display_name);
chat.publish(channel_id, msg.clone().into_inner()).await;
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
channel_id,
msg.user_id as i32,
msg.text,
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
)
.execute(&mut **postgres)
.await
.map_err(|_| "Failed".to_string())?;
println!("gisfujdeghnjuisdfjngiosdfgjkosdf gnojdfsg nmodfsg");
if let Some(ref mut cache) = cache {
messenger::cache::insert(cache, channel_id, &msg)
.await
.map_err(|_| "Redis cache failed".to_string())?;
}
// get response
tokio::spawn(async move {
let response = LlmWorker::new(crate::LMSTUDIO_URL.to_string())
.query(&msg)
.await;
if let Ok(reply) = response {
chat.publish(channel_id, reply.clone()).await;
if let Some(ref mut cache) = cache {
messenger::cache::insert(cache, channel_id, &reply)
.await
.ok();
}
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
channel_id,
reply.user_id as i32,
reply.text,
OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
)
.execute(&mut **postgres)
.await
.map_err(|_| "Failed".to_string())
.unwrap();
}
});
Ok(())
}
pub async fn get_messages(
mut db: Connection<Postgres>,
mut redis: Connection<Redis>,
channel_id: i32,
) -> Json<Vec<ChatMsg>> {
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
&& !messages.is_empty()
{
return Json(messages);
};
if let Err(x) = messenger::cache::initialise(&mut redis, &mut db, channel_id).await {
eprintln!("WARN: {x:?}");
}
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
&& !messages.is_empty()
{
return Json(messages);
};
let res = sqlx::query!(
"SELECT u.username, u.display_name, u.id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.channel_id = $1
ORDER BY m.created_at DESC LIMIT 100",
channel_id
)
.fetch_all(&mut **db)
.await
.unwrap_or_else(|_| Vec::new())
.into_iter()
.rev()
.map(|msg| ChatMsg {
display_name: Some(msg.display_name.unwrap_or(msg.username)),
user_id: msg.id as usize,
text: msg.content,
timestamp: (msg.created_at.unwrap().unix_timestamp_nanos() / 1_000_000) as usize,
})
.collect();
Json(res)
}
#[get("/events/<channel_id>")]
pub async fn event_stream(
chat: &rocket::State<Arc<ChatBroadcaster>>,
postgres: Connection<Postgres>,
cache: Connection<Redis>,
_session: Session,
mut shutdown: Shutdown,
channel_id: i32,
) -> EventStream![] {
let mut rx = chat.subscribe(channel_id);
EventStream! {
// Initialize the stream with the last 100 messages
for msg in get_messages(postgres, cache, channel_id).await.0 {
yield Event::json(&msg);
}
loop {
select!{
// exit early on shutdown
_ = &mut shutdown => break,
msg = rx.recv() => match msg {
Ok(msg) => yield Event::json(&msg),
Err(broadcast::error::RecvError::Lagged(_)) => {
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => break,
},
}
}
}
}
#[get("/chat")]
pub async fn chat_page(session: Session) -> Template {
Template::render("chat", context!(user_id: session.user_id))
}
+1 -4
View File
@@ -1,4 +1 @@
mod cache; // mod cache;
mod messages;
pub use messages::{ChatBroadcaster, ChatMsg, chat_page, event_stream, get_messages, post_message};
+35
View File
@@ -0,0 +1,35 @@
use rocket::serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
#[derive(Serialize, Deserialize)]
pub struct SignupCredentials {
pub email: String,
pub username: String,
pub password: String,
pub access_token: String,
}
#[derive(Serialize, Deserialize)]
pub struct LoginCredentials {
pub username: String,
pub password: String,
}
#[derive(Serialize, Deserialize)]
pub struct AuthResponse {
pub token: String,
pub totp_required: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenForm {
pub name: String,
pub max_uses: i32,
pub expiry_date: DateTime<Utc>,
pub start_date: DateTime<Utc>,
}
pub struct AccessToken {
pub id: i64,
pub code: String,
}
+3
View File
@@ -0,0 +1,3 @@
pub mod auth;
pub mod user;
pub mod space;
+35
View File
@@ -0,0 +1,35 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use chrono::{DateTime, Utc};
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Space {
pub id: i64,
pub name: String,
pub description: Option<String>,
pub owner_id: i64,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Channel {
pub id: i64,
pub name: String,
pub description: Option<String>,
pub space_id: i64,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpaceDto {
pub channels: Vec<Channel>,
pub id: i64,
pub owner_id: i64,
pub name: String,
pub description: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
+52
View File
@@ -0,0 +1,52 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::user_svc::UserService;
use chrono::{DateTime, Utc};
use rocket::State;
use sqlx::FromRow;
use crate::api::totp::TotpStatus;
#[derive(Clone)]
#[derive(FromRow)]
pub struct User {
pub id: i64,
pub email: Option<String>,
pub username: String,
pub nickname: Option<String>,
pub passhash: String,
pub totp_status: TotpStatus,
pub totp_secret: Option<String>,
pub created_at: Option<DateTime<Utc>>,
pub updated_at: Option<DateTime<Utc>>,
}
// pub struct UserCache {}
//
// impl UserCache {
// pub async fn username(
// id: usize,
// redis_conn: &mut Connection<Redis>,
// pgsql_conn: &mut Connection<Postgres>,
// ) -> String {
// if let Ok(val) = redis_conn.get(format!("users:{id}")).await {
// return val;
// }
//
// if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
// .fetch_one(&mut ***pgsql_conn)
// .await
// {
// let username = v.username;
// Self::insert(id, &username, redis_conn).await;
// username
// } else {
// unimplemented!()
// }
// }
//
// pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
// conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800)
// .await
// .expect("failed to insert key");
// }
// }
+67
View File
@@ -0,0 +1,67 @@
use crate::repo::{Repo, AccessTokenRepoTrait};
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use crate::model::auth::AccessToken;
#[derive(Clone)]
pub struct AccessTokenRepo {
pool: PgPool
}
impl Repo for AccessTokenRepo {
type Target = AccessToken;
fn new(pool: PgPool) -> Self {
Self { pool }
}
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
sqlx::query_as!(AccessToken, "SELECT id, code FROM access_tokens WHERE id = $1", id)
.fetch_optional(&self.pool)
.await
.unwrap_or(None)
}
}
#[async_trait]
impl AccessTokenRepoTrait for AccessTokenRepo {
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
Repo::get_by_id(self, id).await
}
async fn create_new(&self,
uid: i64, name: &str, code: &str, max_uses: i32,
start_date: DateTime<Utc>, expiry_date: DateTime<Utc>
) -> Result<i64, sqlx::Error> {
sqlx::query!(
"INSERT INTO access_tokens (name, code, creator_id, max_uses, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
name,
code,
uid,
max_uses,
start_date,
expiry_date
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
}
async fn use_token(&self, id: i64) -> Result<(), sqlx::Error> {
sqlx::query!("UPDATE access_tokens SET uses = uses + 1 WHERE id = $1", id)
.execute(&self.pool)
.await?;
Ok(())
}
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error> {
sqlx::query_as!(AccessToken,
"SELECT id, code FROM access_tokens
WHERE code = $1
AND created_at < NOW()
AND expires_at > NOW()
AND uses < max_uses",
code
)
.fetch_optional(&self.pool)
.await
}
}
+39
View File
@@ -0,0 +1,39 @@
use crate::repo::ChannelRepo;
use crate::model::space::Channel;
use sqlx::PgPool;
#[derive(Clone)]
pub struct ChannelRepository {
pool: PgPool,
}
impl ChannelRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
#[rocket::async_trait]
impl ChannelRepo for ChannelRepository {
async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result<i64, sqlx::Error> {
let row = sqlx::query!(
"INSERT INTO channels (name, description, space_id) VALUES ($1, $2, $3) RETURNING id",
name,
description,
space_id
)
.fetch_one(&self.pool)
.await?;
Ok(row.id)
}
async fn get_by_space_id(&self, space_id: i64) -> Result<Vec<Channel>, sqlx::Error> {
sqlx::query_as!(
Channel,
"SELECT id, name, description, space_id, created_at as \"created_at!\", updated_at as \"updated_at!\" FROM channels WHERE space_id = $1",
space_id
)
.fetch_all(&self.pool)
.await
}
}
+95
View File
@@ -0,0 +1,95 @@
use crate::api::chat::ChatMsg;
use crate::repo::Repo;
use chrono::{DateTime, Utc};
use sqlx::PgPool;
#[derive(Clone)]
pub struct MessageRepository {
pool: PgPool
}
impl Repo for MessageRepository {
type Target = ChatMsg;
fn new(pool: PgPool) -> Self {
Self { pool }
}
// TODO: caching with redis
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.id = $1",
id
).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg {
display_name: Some(row.nickname.unwrap_or(row.username)),
user_id: row.user_id,
text: row.content,
timestamp: row.created_at,
})
}
}
impl MessageRepository {
// TODO! caching with redis
pub async fn create_new(
&self, uid: i64, channel_id: i64,
text: &str, created_at: DateTime<Utc>
) -> Result<i64, sqlx::Error> {
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at)
VALUES ($1, $2, $3, $4) RETURNING id",
channel_id,
uid,
text,
created_at
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
}
/// TODO: caching with redis
pub async fn get_by_channel(&self, channel_id: i64, limit: usize)
-> Result<Vec<ChatMsg>, sqlx::Error> {
sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.channel_id = $1
ORDER BY m.created_at DESC LIMIT $2",
channel_id,
limit as i64
).fetch_all(&self.pool).await.map(|messages| {
messages.into_iter().rev().map(|msg| {
ChatMsg {
display_name: Some(msg.nickname.unwrap_or(msg.username)),
user_id: msg.user_id,
text: msg.content,
timestamp: msg.created_at,
}
}).collect::<Vec<_>>()
})
}
}
+153
View File
@@ -0,0 +1,153 @@
use crate::repo::{UserRepo, AccessTokenRepoTrait};
use crate::model::user::User;
use crate::model::auth::AccessToken;
use rocket::async_trait;
use std::sync::Mutex;
use chrono::Utc;
use std::sync::Arc;
use sqlx::Error;
use crate::api::totp::TotpStatus;
use crate::api::totp::TotpStatus::Disabled;
pub struct MockAccessTokenRepo {
pub tokens: Mutex<Vec<AccessToken>>,
}
#[async_trait]
impl AccessTokenRepoTrait for MockAccessTokenRepo {
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
self.tokens.lock().unwrap().iter().find(|t| t.id == id).map(|t| AccessToken { id: t.id, code: t.code.clone() })
}
async fn create_new(&self, _uid: i64, _name: &str, code: &str, _max_uses: i32, _start_date: chrono::DateTime<Utc>, _expiry_date: chrono::DateTime<Utc>) -> Result<i64, sqlx::Error> {
let mut tokens = self.tokens.lock().unwrap();
let id = tokens.len() as i64 + 1;
tokens.push(AccessToken { id, code: code.to_string() });
Ok(id)
}
async fn use_token(&self, id: i64) -> Result<(), Error> {
// let mut tokens = self.tokens.lock().unwrap();
// if let Some(pos) = tokens.iter().position(|t| t.id == id) {
// tokens.get_mut(pos).uses =
// }
Ok(())
}
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, Error> {
Ok(self.tokens.lock().unwrap()
.iter().find(|t| t.code == code)
.map(|t| AccessToken { id: t.id, code: t.code.clone() }))
}
}
pub struct MockUserRepo {
pub users: Mutex<Vec<User>>,
}
#[async_trait]
impl UserRepo for MockUserRepo {
fn pool(&self) -> &sqlx::PgPool {
unimplemented!("MockUserRepo does not have a real pool")
}
async fn get_by_id(&self, id: i64) -> Option<User> {
self.users.lock().unwrap().iter().find(|u| u.id == id).cloned()
}
async fn save(&self, user: &User) -> Result<(), sqlx::Error> {
let mut users = self.users.lock().unwrap();
if let Some(pos) = users.iter().position(|u| u.id == user.id) {
users[pos] = user.clone();
}
Ok(())
}
async fn new_user(&self, email: &str, username: &str, pass_hash: &str) -> Result<i64, sqlx::Error> {
let mut users = self.users.lock().unwrap();
let id = users.len() as i64 + 1;
users.push(User {
id,
email: Some(email.to_string()),
username: username.to_string(),
nickname: None,
passhash: pass_hash.to_string(),
totp_status: Disabled,
totp_secret: None,
created_at: Some(Utc::now()),
updated_at: Some(Utc::now()),
});
Ok(id)
}
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error> {
Ok(self.users.lock().unwrap().iter().find(|u| u.username == username).cloned())
}
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error> {
self.users.lock().unwrap().retain(|u| u.id != id);
Ok(())
}
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.nickname = display_name;
}
Ok(())
}
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.username = username.to_string();
}
Ok(())
}
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.totp_status = *enabled;
}
Ok(())
}
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error> {
Ok(self.users.lock().unwrap().iter().find(|u| u.id == id).and_then(|u| u.totp_secret.clone()))
}
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.totp_secret = secret;
}
Ok(())
}
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error> {
Ok(self.users.lock().unwrap().iter().find(|u| u.id == id).map(|u| u.passhash.clone()).unwrap())
}
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.passhash = pass_hash.to_string();
}
Ok(())
}
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.email = Some(email.to_string());
}
Ok(())
}
async fn set_role(&self, _id: i64, _role: &str) -> Result<(), sqlx::Error> {
Ok(())
}
}
pub struct MockTokenRepo {
pub tokens: Mutex<Vec<AccessToken>>,
}
#[async_trait]
impl AccessTokenRepoTrait for MockTokenRepo {
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
self.tokens.lock().unwrap().iter().find(|t| t.id == id).map(|t| AccessToken { id: t.id, code: t.code.clone() })
}
async fn create_new(&self, _uid: i64, _name: &str, code: &str, _max_uses: i32, _start_date: chrono::DateTime<Utc>, _expiry_date: chrono::DateTime<Utc>) -> Result<i64, sqlx::Error> {
let mut tokens = self.tokens.lock().unwrap();
let id = tokens.len() as i64 + 1;
tokens.push(AccessToken { id, code: code.to_string() });
Ok(id)
}
async fn use_token(&self, _id: i64) -> Result<(), sqlx::Error> {
Ok(())
}
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error> {
Ok(self.tokens.lock().unwrap().iter().find(|t| t.code == code).map(|t| AccessToken { id: t.id, code: t.code.clone() }))
}
}
+64
View File
@@ -0,0 +1,64 @@
use crate::model::auth::AccessToken;
use crate::model::user::User;
use chrono::{DateTime, Utc};
use crate::api::totp::TotpStatus;
use crate::model::space::Space;
pub mod user_repo;
pub mod message_repo;
pub mod access_token_repo;
pub mod space_repo;
pub mod channel_repo;
pub mod mock;
pub trait Repo: Clone + Send + Sync {
type Target;
fn new(pool: sqlx::PgPool) -> Self;
async fn get_by_id(&self, id: i64) -> Option<Self::Target>;
}
#[rocket::async_trait]
pub trait UserRepo: Send + Sync {
fn pool(&self) -> &sqlx::PgPool;
async fn get_by_id(&self, id: i64) -> Option<User>;
async fn save(&self, user: &User) -> Result<(), sqlx::Error>;
async fn new_user(&self, email: &str, username: &str, pass_hash: &str) -> Result<i64, sqlx::Error>;
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error>;
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error>;
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error>;
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error>;
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error>;
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error>;
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error>;
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error>;
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error>;
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error>;
async fn set_role(&self, id: i64, role: &str) -> Result<(), sqlx::Error>;
}
#[rocket::async_trait]
pub trait SpaceRepo: Send + Sync {
async fn create(&self, name: &str, description: Option<&str>, owner_id: i64) -> Result<i64, sqlx::Error>;
async fn get_all(&self) -> Result<Vec<crate::model::space::Space>, sqlx::Error>;
async fn get_by_member(&self, uid: i64) -> Result<Vec<Space>, sqlx::Error>;
async fn get_by_id(&self, id: i64) -> Result<Option<crate::model::space::Space>, sqlx::Error>;
}
#[rocket::async_trait]
pub trait ChannelRepo: Send + Sync {
async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result<i64, sqlx::Error>;
async fn get_by_space_id(&self, space_id: i64) -> Result<Vec<crate::model::space::Channel>, sqlx::Error>;
}
#[async_trait]
pub trait AccessTokenRepoTrait: Send + Sync {
async fn get_by_id(&self, id: i64) -> Option<AccessToken>;
async fn create_new(&self,
uid: i64, name: &str, code: &str, max_uses: i32,
start_date: DateTime<Utc>, expiry_date: DateTime<Utc>
) -> Result<i64, sqlx::Error>;
async fn use_token(&self, id: i64) -> Result<(), sqlx::Error>;
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error>;
}
+56
View File
@@ -0,0 +1,56 @@
use crate::repo::SpaceRepo;
use crate::model::space::Space;
use sqlx::PgPool;
#[derive(Clone)]
pub struct SpaceRepository {
pool: PgPool,
}
impl SpaceRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
#[rocket::async_trait]
impl SpaceRepo for SpaceRepository {
async fn create(&self, name: &str, description: Option<&str>, owner_id: i64) -> Result<i64, sqlx::Error> {
let row = sqlx::query!(
"INSERT INTO spaces (name, description, owner_id) VALUES ($1, $2, $3) RETURNING id",
name,
description,
owner_id
)
.fetch_one(&self.pool)
.await?;
Ok(row.id)
}
async fn get_all(&self) -> Result<Vec<Space>, sqlx::Error> {
sqlx::query_as!(Space,
"SELECT id, name, description, owner_id, created_at, updated_at FROM spaces"
)
.fetch_all(&self.pool)
.await
}
async fn get_by_member(&self, uid: i64) -> Result<Vec<Space>, sqlx::Error> {
sqlx::query_as!(Space,
"SELECT s.id, s.name, s.description, s.created_at, s.updated_at, s.owner_id
FROM spaces s JOIN space_members sm ON s.id = sm.space_id
WHERE sm.user_id = $1",
uid
).fetch_all(&self.pool)
.await
}
async fn get_by_id(&self, id: i64) -> Result<Option<Space>, sqlx::Error> {
sqlx::query_as!(Space,
"SELECT id, name, description, owner_id, created_at, updated_at FROM spaces WHERE id = $1",
id
)
.fetch_optional(&self.pool)
.await
}
}
+212
View File
@@ -0,0 +1,212 @@
use crate::repo::{Repo, UserRepo};
use crate::model::user::User;
use sqlx::PgPool;
use crate::api::totp::TotpStatus;
#[derive(Clone)]
pub struct UserRepository {
pool: PgPool
}
impl UserRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
pub fn pool(&self) -> &PgPool {
&self.pool
}
}
impl Repo for UserRepository {
type Target = User;
fn new(pool: PgPool) -> Self {
Self::new(pool)
}
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
sqlx::query_as!(
User,
"SELECT id, email, username, nickname, passhash, totp_status as \"totp_status!: TotpStatus\", totp_secret, created_at, updated_at FROM users WHERE id = $1",
id
)
.fetch_optional(&self.pool)
.await
.map_err(|e| {
tracing::error!("Database error in get_by_id: {}", e);
e
})
.ok()?
}
}
#[async_trait]
impl UserRepo for UserRepository {
fn pool(&self) -> &sqlx::PgPool {
&self.pool
}
async fn get_by_id(&self, id: i64) -> Option<User> {
Repo::get_by_id(self, id).await
}
async fn save(&self, user: &User) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET email = $1, username = $2, nickname = $3, passhash = $4, totp_status = $5, totp_secret = $6, created_at = $7, updated_at = $8 WHERE id = $9",
user.email,
user.username,
user.nickname,
user.passhash,
user.totp_status as TotpStatus,
user.totp_secret,
user.created_at,
user.updated_at,
user.id
).execute(&self.pool).await?;
Ok(())
}
async fn new_user(&self, email: &str, username: &str, passhash: &str) -> Result<i64, sqlx::Error> {
sqlx::query!(
"INSERT INTO users (email, username, passhash) VALUES ($1, $2, $3) RETURNING id",
email,
username,
passhash
)
.fetch_optional(&self.pool)
.await
.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
}
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error> {
sqlx::query_as!(
User,
"SELECT id, email, username, nickname, passhash, totp_status as \"totp_status!: TotpStatus\", totp_secret, created_at, updated_at FROM users WHERE username = $1",
username
)
.fetch_optional(&self.pool)
.await
}
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error> {
sqlx::query!("DELETE FROM users WHERE id = $1", id)
.execute(&self.pool)
.await?;
Ok(())
}
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET nickname = $1 WHERE id = $2",
display_name,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET username = $1 WHERE id = $2",
username,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET totp_status = $1 WHERE id = $2",
enabled as &TotpStatus,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error> {
sqlx::query!(
"SELECT totp_secret FROM users WHERE id = $1",
id
)
.fetch_optional(&self.pool)
.await
.map(|opt| opt.and_then(|row| row.totp_secret))
}
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET totp_secret = $1 WHERE id = $2",
secret,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error> {
sqlx::query!(
"SELECT passhash FROM users WHERE id = $1",
id
)
.fetch_optional(&self.pool)
.await
.and_then(|row| row.map(|r| r.passhash).ok_or(sqlx::Error::RowNotFound))
}
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET passhash = $1 WHERE id = $2",
pass_hash,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET email = $1 WHERE id = $2",
email,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn set_role(&self, id: i64, role: &str) -> Result<(), sqlx::Error> {
sqlx::query(
"UPDATE users SET role = $1::user_role WHERE id = $2"
)
.bind(role)
.bind(id)
.execute(&self.pool).await?;
Ok(())
}
}
+47
View File
@@ -0,0 +1,47 @@
use std::sync::Arc;
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::error::{ApiResult, AppError};
use crate::model::auth::AccessToken;
use crate::repo::access_token_repo::AccessTokenRepo;
use crate::repo::AccessTokenRepoTrait;
#[derive(Clone)]
pub struct AccessTokenService {
repo: Arc<dyn AccessTokenRepoTrait>
}
impl AccessTokenService {
pub fn new(repo: Arc<dyn AccessTokenRepoTrait>) -> Self {
Self { repo }
}
pub async fn create(&self,
uid: i64, name: &str, max_uses: i32,
valid_from: DateTime<Utc>, valid_until: DateTime<Utc>
) -> ApiResult<String> {
if valid_from > valid_until {
return Err(AppError::bad_request("start date must be before end date"))
}
if valid_until < Utc::now() {
return Err(AppError::bad_request("expiry date must be after current date"))
}
let code = Uuid::new_v4().to_string();
self.repo.create_new(uid, name, &code, max_uses, valid_from, valid_until).await?;
Ok(code)
}
pub async fn get_valid_token(&self, token: &str) -> ApiResult<AccessToken> {
self.repo.get_code_not_expired(token).await?
.ok_or(AppError::unauthorised("invalid access token"))
}
pub async fn use_token(&self, id: i64) -> ApiResult<()> {
self.repo.use_token(id).await?;
Ok(())
}
}
+259
View File
@@ -0,0 +1,259 @@
use crate::api::auth::{Claims, TokenScope};
use crate::api::totp::totp_gen;
use crate::error::{ApiResult, AppError};
use crate::model::auth::AuthResponse;
use crate::repo::{UserRepo, AccessTokenRepoTrait};
use std::sync::Arc;
use argon2::password_hash::rand_core::OsRng;
use argon2::password_hash::SaltString;
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::api::totp::TotpStatus::{Disabled, Enabled};
use crate::svc::access_token_svc::AccessTokenService;
#[derive(Clone)]
pub struct AuthService {
users: Arc<dyn UserRepo>,
tokens: AccessTokenService,
}
impl AuthService {
pub fn new(users: Arc<dyn UserRepo>, tokens: AccessTokenService) -> Self {
Self { users, tokens }
}
pub async fn signup(&self,
email: &str, username: &str,
password: &str, access_token: &str
) -> ApiResult<AuthResponse> {
let tok_id = self.tokens.get_valid_token(access_token).await?.id;
let pass = password.to_string();
let svc = self.clone();
let hashed = tokio::task::spawn_blocking(move || svc.hash_password(&pass))
.await
.map_err(|_| AppError::internal("blocking task panicked"))??;
let uid = self.users
.new_user(email, username, &hashed).await?;
self.tokens.use_token(tok_id).await?;
let jwt = Claims::new(uid as usize, TokenScope::Full).encode();
Ok(AuthResponse {
token: jwt,
totp_required: false
})
}
pub async fn login(&self, username: &str, password: &str) -> ApiResult<AuthResponse> {
let user = self.users
.get_by_username(username).await?
.ok_or(AppError::unauthorised("invalid username"))?;
let pass = password.to_string();
let user_hash = user.passhash.clone();
let svc = self.clone();
tokio::task::spawn_blocking(move || svc.verify_password(&user_hash, &pass))
.await
.map_err(|_| AppError::internal("blocking task panicked"))??;
let scope = if user.totp_status == Enabled { TokenScope::TotpPending } else { TokenScope::Full };
let jwt = Claims::new(user.id as usize, scope).encode();
Ok(AuthResponse {
token: jwt,
totp_required: user.totp_status == Enabled
})
}
pub async fn login_totp(&self, uid: i64, code: &str) -> ApiResult<AuthResponse> {
let secret = self.users.get_totp_secret(uid).await?
.ok_or(AppError::unauthorised("2fa not enabled"))?;
self.verify_2fa(uid, &secret, code)?;
let jwt = Claims::new(uid as usize, TokenScope::Full).encode();
Ok(AuthResponse {
token: jwt,
totp_required: false
})
}
pub async fn disable_totp(&self, uid: i64, password: &str, totp_code: &str) -> ApiResult<AuthResponse> {
let mut user = self.users.get_by_id(uid).await
.ok_or(AppError::internal("user not found"))?;
let Some(secret) = user.totp_secret else {
return Err(AppError::bad_request("2fa not enabled"));
};
self.verify_password(&user.passhash, password)?;
self.verify_2fa(uid, &secret, totp_code)?;
user.totp_secret = None;
user.totp_status = Disabled;
self.users.save(&user).await?;
Ok(AuthResponse {
token: Claims::new(uid as usize, TokenScope::Full).encode(),
totp_required: false
})
}
pub async fn get_totp_status(&self, uid: i64) -> ApiResult<bool> {
Ok(
self.users.get_totp_secret(uid).await?.is_some()
)
}
pub async fn confirm_totp(&self, uid: i64, totp_code: &str) -> ApiResult<()> {
let secret = self.users.get_totp_secret(uid).await?
.ok_or(AppError::bad_request("2fa setup not initialised"))?;
self.verify_2fa(uid, &secret, totp_code)?;
self.users.set_twofa_enabled(uid, &Enabled).await?;
Ok(())
}
pub async fn get_or_create_totp_secret(
&self, uid: i64, password: &str,
) -> ApiResult<String> {
let user = self.users.get_by_id(uid).await
.ok_or(AppError::internal("user not found"))?;
let pass = password.to_string();
let user_hash = user.passhash.clone();
let svc = self.clone();
tokio::task::spawn_blocking(move || svc.verify_password(&user_hash, &pass))
.await
.map_err(|_| AppError::internal("blocking task panicked"))??;
if let Some(secret) = user.totp_secret {
return Ok(secret);
}
let new_secret = totp_rs::Secret::generate_secret()
.to_encoded()
.to_string();
self.users.set_totp_secret(uid, Some(new_secret.clone())).await?;
Ok(new_secret)
}
pub async fn verify_user_password(&self, uid: i64, password: &str) -> ApiResult<()> {
let hash = self.users.get_pass_hash(uid).await
.map_err(|_| AppError::internal("user not found"))?;
let pass = password.to_string();
let svc = self.clone();
tokio::task::spawn_blocking(move || svc.verify_password(&hash, &pass))
.await
.map_err(|_| AppError::internal("blocking task panicked"))??;
Ok(())
}
pub async fn verify_user_totp(&self, uid: i64, totp_code: &str) -> ApiResult<()> {
let secret = self.users.get_totp_secret(uid).await?
.ok_or(AppError::internal("user not found"))?;
self.verify_2fa(uid, &secret, totp_code)
}
pub fn hash_password(&self, password: &str) -> ApiResult<String> {
let salt = SaltString::generate(&mut OsRng);
Argon2::default()
.hash_password(password.as_bytes(), &salt)
.map_err(|_| AppError::internal("failed to hash password"))
.map(|hash| hash.to_string())
}
// Private helpers
fn verify_password(&self, pass_hash: &str, password: &str) -> ApiResult<()> {
let parsed_hash = PasswordHash::new(&pass_hash)
.map_err(|_| AppError::internal("invalid password hash"))?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|_| AppError::unauthorised("incorrect password"))?;
Ok(())
}
pub fn verify_2fa(&self, uid: i64, totp_secret: &str, totp_code: &str) -> ApiResult<()> {
if totp_gen(uid, totp_secret.as_bytes())
.map_err(|_| AppError::internal("invalid totp secret"))?
.check_current(totp_code)
.map_err(|_| AppError::internal("invalid totp code"))? {
Ok(())
} else {
Err(AppError::unauthorised("incorrect totp code"))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::repo::mock::{MockUserRepo, MockTokenRepo};
use std::sync::Mutex;
fn setup() -> AuthService {
unsafe {
std::env::set_var("JWT_SECRET", "test_secret");
}
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tok_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let tokens = AccessTokenService::new(tok_repo);
AuthService::new(users, tokens)
}
#[tokio::test]
async fn test_signup_and_login() {
let auth = setup();
let code = auth.tokens.create(1, "test", 1, Utc::now(), Utc::now()).await.unwrap();
let signup_res = auth.signup("test@example.com", "tester", "password123", &code).await;
assert!(signup_res.is_ok());
let login_res = auth.login("tester", "password123").await;
assert!(login_res.is_ok());
let login_data = login_res.unwrap();
assert!(!login_data.totp_required);
assert!(!login_data.token.is_empty());
}
#[tokio::test]
async fn test_login_invalid_password() {
let auth = setup();
let token_code = auth.tokens.create(1, "test", 1, Utc::now(), Utc::now()).await.unwrap();
auth.signup("test@example.com", "tester", "password123", &token_code).await.unwrap();
let login_res = auth.login("tester", "wrong_password").await;
assert!(login_res.is_err());
if let Err(AppError::Unauthorised(msg)) = login_res {
assert_eq!(msg, "incorrect password");
} else {
panic!("Expected Unauthorised error");
}
}
#[tokio::test]
async fn test_invite() {
let auth = setup();
let res = auth.tokens.create(1, "invite", 1, Utc::now(), Utc::now() + chrono::Duration::days(1)).await;
assert!(res.is_ok());
let code = res.unwrap();
assert!(!code.is_empty());
let token = auth.tokens.get_valid_token(&code).await;
assert!(token.is_ok());
}
}
+187
View File
@@ -0,0 +1,187 @@
use crate::api::chat::ChatMsg;
use crate::error::{ApiResult, AppError};
use crate::repo::message_repo::MessageRepository;
use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo};
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::broadcast::Sender;
use tokio::sync::{broadcast, Mutex};
use crate::model::space::SpaceDto;
use crate::svc::llm_service::LlmService;
/// ---------- shared broadcaster ----------
#[derive(Clone)]
pub struct ChatService {
users: Arc<dyn UserRepo>,
channels: Arc<dyn ChannelRepo>,
spaces: Arc<dyn SpaceRepo>,
messages: MessageRepository,
llm: LlmService,
buffer_size: usize,
senders: Arc<Mutex<HashMap<i64, Sender<ChatMsg>>>>,
}
impl ChatService {
pub fn new(
buffer_size: usize, llm: LlmService,
messages: MessageRepository, users: Arc<dyn UserRepo>,
channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>,
) -> Self {
Self {
channels,
spaces,
llm,
users,
messages,
buffer_size,
senders: Arc::new(Mutex::new(std::collections::HashMap::new())),
}
}
pub async fn get_accessible_channels(&self, uid: i64) -> ApiResult<Vec<SpaceDto>> {
// let spaces = self.spaces.get_by_member(uid).await?;
// TODO! UNCOMMENT THIS ^^^^^^
let spaces = self.spaces.get_all().await?;
let mut result = Vec::new();
for space in spaces {
let channels = self.channels.get_by_space_id(space.id).await?;
result.push(SpaceDto {
channels,
id: space.id,
owner_id: space.owner_id,
name: space.name,
description: space.description,
created_at: space.created_at,
updated_at: space.updated_at,
});
}
Ok(result)
}
pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
let messages = self.messages.get_by_channel(channel_id, limit).await?;
Ok(messages)
}
/// Sends a chat message to the specified channel, persists it to the database,
/// and handles potential AI-generated replies asynchronously.
///
/// # Parameters
/// - `channel_id`: The ID of the channel to which the message will be sent.
/// - `uid`: The user ID of the sender.
/// - `text`: The content of the message to be sent.
/// - `created_at`: The timestamp at which the message was created.
///
/// # Returns
/// - `ApiResult<()>`: Indicates success or failure of the operation.
///
/// # Behavior
/// 1. Fetches the user by their `uid`. Returns an error if the user is not found.
/// 2. Constructs a `ChatMsg` object with the sender's `display_name` or `username`,
/// and the specified message content and timestamp.
/// 3. Publishes the constructed message to the given channel.
/// 4. Persists the message in the database.
/// 5. Spawns an asynchronous task to generate an LLM-powered (language model) reply:
/// - Sends the original message to the LLM worker for a potential reply.
/// - Publishes the LLM's reply to the same channel if successful.
/// - Persists the LLM's reply to the database.
///
/// # Notes
/// - Caching with Redis is planned for both message persistence and AI replies, but
/// is not implemented in the current version.
/// - The spawned asynchronous task does not block the main execution flow.
///
/// # Potential Errors
/// - Returns `AppError::NotFound` if the `uid` does not map to an existing user.
/// - Returns an error wrapped in `ApiResult` if the database operations fail.
///
/// # TODO
/// - Implement caching for both user-supplied messages and LLM-generated replies
/// using Redis at the repository or service layer.
pub async fn send(&self,
channel_id: i64, uid: i64,
text: &str, created_at: DateTime<Utc>
) -> ApiResult<()> {
let user = self.users.get_by_id(uid).await
.ok_or(AppError::NotFound)?;
let message = ChatMsg {
display_name: Some(user
.nickname.clone()
.unwrap_or_else(|| user.username.clone())),
user_id: uid,
text: text.to_string(),
timestamp: created_at,
};
self.publish(channel_id, message.clone()).await;
let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?;
// TODO: caching w redis at repository layer
let svc_instance = self.clone();
let Some(text) = text.strip_prefix("/ask ") else {
return Ok(())
};
if !svc_instance.llm.enabled() {
return Ok(())
}
tokio::spawn(async move {
let response = svc_instance.llm
.query(&message)
.await;
if let Ok(reply) = response {
tracing::info!("LLM reply: {}", reply.text);
svc_instance.publish(channel_id, reply.clone()).await;
// TODO: cache response (or do with redis!)
if let Err(e) = svc_instance.messages
.create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await {
tracing::error!("Failed to persist LLM reply: {}", e);
}
tracing::info!("LLM reply persisted");
} else {
tracing::warn!("Error contacting LLM: {:?}", response);
}
});
Ok(())
}
/// Subscribe to the specified channel.
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatMsg> {
let mut map = self.senders.lock().await;
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
sender.subscribe()
}
// Private helper methods
/// Publish a message to the specified channel.
async fn publish(&self, channel_id: i64, msg: ChatMsg) {
let mut map = self.senders.lock().await;
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
let _ = sender.send(msg);
}
}
+89
View File
@@ -0,0 +1,89 @@
#[derive(Clone)]
pub struct LlmService;
static LMSTUDIO_URL: LazyLock<Option<String>> = LazyLock::new(|| env::var("LMSTUDIO_URL").ok());
static LMSTUDIO_MODEL: LazyLock<Option<String>> = LazyLock::new(|| env::var("LMSTUDIO_MODEL").ok());
impl LlmService {
pub fn new() -> Self {
Self {}
}
pub fn enabled(&self) -> bool {
LMSTUDIO_URL.is_some()
}
pub async fn query(&self, message: &ChatMsg) -> ApiResult<ChatMsg> {
let Some(url) = LMSTUDIO_URL.clone() else {
return Err(AppError::internal("AI not enabled!"))
};
let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into());
let client = reqwest::Client::new();
// Build the request body
let payload = LlmRequest {
model, // whatever model you run locally
messages: vec![Message {
role: "user".into(),
content: message.text.clone(),
}],
};
// POST to lmstudio (default 127.0.0.1:1234)
let resp = client
.post(url)
.json(&payload)
.send()
.await
.map_err(|_| AppError::internal("Failed to make request to LLM server"))?;
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)]
struct LlmResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize)]
struct Choice {
message: Message,
}
let llm_resp: LlmResponse = resp
.json()
.await
.map_err(|_| AppError::internal("Failed to parse LLM response"))?;
Ok(ChatMsg {
display_name: Some(String::from("llm")),
user_id: 0,
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now(),
})
}
}
use std::env;
use std::sync::LazyLock;
// src/llm.rs
use serde::{Deserialize, Serialize};
use crate::api::chat::ChatMsg;
use crate::error::{ApiResult, AppError};
use crate::svc::chat_svc::ChatService;
#[derive(Serialize)]
struct LlmRequest {
model: String,
messages: Vec<Message>,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String, // "user" or "assistant"
content: String,
}
+6
View File
@@ -0,0 +1,6 @@
pub mod auth_svc;
pub mod chat_svc;
pub mod settings_svc;
pub mod user_svc;
pub mod access_token_svc;
pub mod llm_service;
+116
View File
@@ -0,0 +1,116 @@
//! The `SettingsService` is responsible for managing user account settings, allowing users to
//! update their username, password, display name, email, and delete their account.
//! It interacts with the `AuthService` to handle authentication and password-related functionality
//! and the `UserRepository` to perform updates to user accounts in the data store.
use crate::error::{ApiResult, AppError};
use crate::repo::UserRepo;
use crate::svc::auth_svc::AuthService;
use std::sync::Arc;
#[derive(Clone)]
pub struct SettingsService {
auth: AuthService,
users: Arc<dyn UserRepo>,
}
impl SettingsService {
pub fn new(auth: AuthService, users: Arc<dyn UserRepo>) -> Self {
Self { auth, users }
}
pub async fn change_username(&self, uid: i64, new: &str) -> ApiResult<()> {
self.users.set_username(uid, new).await?;
Ok(())
}
pub async fn change_password(&self, uid: i64, old: &str, new: &str) -> ApiResult<()> {
self.auth.verify_user_password(uid, old).await?;
let hashed = self.auth.hash_password(new)?;
self.users.set_pass_hash(uid, &hashed).await?;
Ok(())
}
pub async fn change_display_name(&self, uid: i64, new: Option<String>) -> ApiResult<()> {
self.users.set_display_name(uid, new).await?;
Ok(())
}
pub async fn change_email(&self, uid: i64, new: &str) -> ApiResult<()> {
self.users.set_email(uid, new).await?;
Ok(())
}
pub async fn delete_account(&self, uid: i64, password: &str, totp_code: &Option<String>) -> ApiResult<()> {
self.auth.verify_user_password(uid, password).await?;
// check 2fa code is correct if enabled
if self.auth.get_totp_status(uid).await? {
let Some(totp_code) = totp_code else {
return Err(AppError::unauthorised("2fa code is required"))
};
self.auth.verify_user_totp(uid, totp_code).await?;
}
self.users.delete_by_id(uid).await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::repo::mock::{MockUserRepo, MockTokenRepo};
use std::sync::Mutex;
use chrono::Utc;
use crate::svc::access_token_svc::AccessTokenService;
fn setup() -> SettingsService {
unsafe {
std::env::set_var("JWT_SECRET", "test_secret");
}
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let token_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let tokens_svc = AccessTokenService::new(token_repo);
let auth = AuthService::new(users.clone(), tokens_svc.clone());
SettingsService::new(auth, users)
}
#[tokio::test]
async fn test_change_username() {
let settings = setup();
let uid = settings.users.new_user("test@example.com", "old", "pass").await.unwrap();
settings.change_username(uid, "new").await.unwrap();
let user = settings.users.get_by_id(uid).await.unwrap();
assert_eq!(user.username, "new");
}
#[tokio::test]
async fn test_change_password() {
let settings = setup();
let pass = "old_pass";
let hashed = settings.auth.hash_password(pass).unwrap();
let uid = settings.users.new_user("test@example.com", "user", &hashed).await.unwrap();
settings.change_password(uid, pass, "new_pass").await.unwrap();
let _user = settings.users.get_by_id(uid).await.unwrap();
assert!(settings.auth.verify_user_password(uid, "new_pass").await.is_ok());
}
#[tokio::test]
async fn test_delete_account() {
let settings = setup();
let pass = "password";
let hashed = settings.auth.hash_password(pass).unwrap();
let uid = settings.users.new_user("test@example.com", "user", &hashed).await.unwrap();
let res = settings.delete_account(uid, pass, &None).await;
assert!(res.is_ok());
let user = settings.users.get_by_id(uid).await;
assert!(user.is_none());
}
}
+28
View File
@@ -0,0 +1,28 @@
use crate::error::ApiResult;
use crate::repo::UserRepo;
use std::sync::Arc;
pub struct UserService {
repo: Arc<dyn UserRepo>
}
impl UserService {
pub fn new(repo: Arc<dyn UserRepo>) -> Self {
Self { repo }
}
pub async fn get_display_name(&self, uid: i64) -> ApiResult<String> {
// TODO: redis caching for display names
let user = self.repo.get_by_id(uid)
.await.ok_or(crate::error::AppError::NotFound)?;
Ok(user.nickname.unwrap_or_else(|| user.username))
}
pub async fn get_username(&self, uid: i64) -> ApiResult<String> {
self.repo.get_by_id(uid)
.await.ok_or(crate::error::AppError::NotFound)
.map(|u| u.username)
}
}
-188
View File
@@ -1,188 +0,0 @@
use argon2::{Argon2, PasswordHash, PasswordVerifier};
use redis::AsyncCommands;
use rocket::{http::Status, serde::json::Json, time::OffsetDateTime};
use rocket_db_pools::Connection;
use crate::{
auth::{Session, two_factor::totp_gen},
db::{Postgres, Redis},
};
pub struct User {
pub id: i32,
pub email: Option<String>,
pub username: String,
pub display_name: Option<String>,
pub pass_hash: String,
pub twofa_enabled: bool,
pub totp_secret: Option<String>,
pub created_at: Option<OffsetDateTime>,
pub updated_at: Option<OffsetDateTime>,
}
impl User {
pub async fn get_by_id(id: usize, db: &mut Connection<Postgres>) -> Option<Self> {
sqlx::query_as!(
Self,
"SELECT id, email, username, display_name, pass_hash, twofa_enabled, totp_secret, created_at, updated_at FROM users WHERE id = $1",
id as i32
)
.fetch_optional(&mut ***db)
.await
.unwrap_or(None)
}
pub async fn delete(&mut self, db: &mut Connection<Postgres>) -> Result<(), sqlx::Error> {
sqlx::query!("DELETE FROM users WHERE id = $1", self.id)
.execute(&mut ***db)
.await?;
Ok(())
}
pub fn verify_2fa(&self, code: &str) -> Result<(), Status> {
if totp_gen(
self.id as usize,
self.totp_secret
.clone()
.expect("user with 2fa enabled has no totp secret")
.as_bytes(),
)
.map_err(|_| Status::InternalServerError)?
.check_current(code)
.map_err(|_| Status::InternalServerError)?
{
Ok(())
} else {
Err(Status::Unauthorized)
}
}
pub fn verify_password(&self, password: &str) -> Result<(), Status> {
let parsed_hash = PasswordHash::new(&self.pass_hash)
.inspect_err(|e| {
tracing::error!("Failed to parse hash for password! uid:{} {e}", self.id)
})
.map_err(|_| Status::InternalServerError)?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|_| Status::Unauthorized)
}
pub async fn set_display_name(
&mut self,
display_name: Option<String>,
db: &mut Connection<Postgres>,
) -> Result<(), sqlx::Error> {
self.display_name = display_name;
sqlx::query!(
"UPDATE users SET display_name = $1 WHERE id = $2",
self.display_name,
self.id
)
.execute(&mut ***db)
.await?;
Ok(())
}
pub async fn set_username(
&mut self,
username: String,
db: &mut Connection<Postgres>,
) -> Result<(), sqlx::Error> {
self.username = username;
sqlx::query!(
"UPDATE users SET username = $1 WHERE id = $2",
self.username,
self.id
)
.execute(&mut ***db)
.await?;
Ok(())
}
pub async fn set_twofa_enabled(
&mut self,
enabled: bool,
db: &mut Connection<Postgres>,
) -> Result<(), sqlx::Error> {
self.twofa_enabled = enabled;
sqlx::query!(
"UPDATE users SET twofa_enabled = $1 WHERE id = $2",
self.twofa_enabled,
self.id
)
.execute(&mut ***db)
.await?;
Ok(())
}
pub async fn set_pass_hash(
&mut self,
pass_hash: String,
db: &mut Connection<Postgres>,
) -> Result<(), sqlx::Error> {
self.pass_hash = pass_hash;
sqlx::query!(
"UPDATE users SET pass_hash = $1 WHERE id = $2",
self.pass_hash,
self.id
)
.execute(&mut ***db)
.await?;
Ok(())
}
}
#[get("/users", rank = 2)]
pub async fn users(_ag: Session, mut db: Connection<Postgres>) -> Json<Vec<i32>> {
sqlx::query!("SELECT id FROM users")
.fetch_all(&mut **db)
.await
.unwrap_or_else(|_| Vec::new())
.into_iter()
.map(|row| row.id)
.collect::<Vec<i32>>()
.into()
}
#[get("/users/<id>", rank = 1)]
pub async fn display_name(
id: usize,
_ag: Session,
mut pgsql_conn: Connection<Postgres>,
mut redis_conn: Connection<Redis>,
) -> String {
UserCache::username(id, &mut redis_conn, &mut pgsql_conn).await
}
pub struct UserCache {}
impl UserCache {
pub async fn username(
id: usize,
redis_conn: &mut Connection<Redis>,
pgsql_conn: &mut Connection<Postgres>,
) -> String {
if let Ok(val) = redis_conn.get(format!("users:{id}")).await {
return val;
}
if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
.fetch_one(&mut ***pgsql_conn)
.await
{
let username = v.username;
Self::insert(id, &username, redis_conn).await;
username
} else {
unimplemented!()
}
}
pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800)
.await
.expect("failed to insert key");
}
}
+198
View File
@@ -0,0 +1,198 @@
use backend::rocket_builder;
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
use backend::repo::message_repo::MessageRepository;
use backend::svc::chat_svc::ChatService;
use backend::repo::user_repo::UserRepository;
use backend::repo::{Repo, AccessTokenRepoTrait};
use rocket::local::asynchronous::Client;
use rocket::http::{Status, ContentType};
use serde_json::json;
use std::sync::{Arc, Mutex};
use sqlx::PgPool;
use chrono::Utc;
use backend::svc::llm_service::LlmService;
async fn test_rocket() -> rocket::Rocket<rocket::Build> {
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = MessageRepository::new(pool.clone());
let user_repo = Arc::new(UserRepository::new(pool));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
rocket_builder(users, tokens, chat_service)
}
#[rocket::async_test]
async fn test_unauthorized_access() {
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
// Attempt to access a protected endpoint without authentication
let response = client.patch("/api/settings/display_name").dispatch().await;
assert_eq!(response.status(), Status::Unauthorized);
let response = client.post("/api/settings/password").dispatch().await;
assert_eq!(response.status(), Status::Unauthorized);
let response = client.delete("/api/settings").dispatch().await;
assert_eq!(response.status(), Status::Unauthorized);
}
#[rocket::async_test]
async fn test_signup_invalid_token() {
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
let signup_data = json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": "invalid-token"
});
let response = client.post("/api/signup")
.header(ContentType::JSON)
.body(signup_data.to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Unauthorized);
}
#[rocket::async_test]
async fn test_login_invalid_credentials() {
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
let login_data = json!({
"username": "nonexistent",
"password": "wrongpassword"
});
let response = client.post("/api/login")
.header(ContentType::JSON)
.body(login_data.to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Unauthorized);
}
#[rocket::async_test]
async fn test_full_auth_flow() {
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = MessageRepository::new(pool.clone());
let user_repo = Arc::new(UserRepository::new(pool));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
let token_code = "valid-token";
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
let client = Client::tracked(rocket_builder(users, tokens, chat_service)).await.expect("valid rocket instance");
// 1. Signup
let signup_data = json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": token_code
});
let response = client.post("/api/signup")
.header(ContentType::JSON)
.body(signup_data.to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let body = response.into_string().await.unwrap();
assert!(body.contains("token"));
// 2. Login
let login_data = json!({
"username": "testuser",
"password": "password123"
});
let response = client.post("/api/login")
.header(ContentType::JSON)
.body(login_data.to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let body = response.into_string().await.unwrap();
assert!(body.contains("token"));
}
#[rocket::async_test]
async fn test_delete_account_security() {
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = MessageRepository::new(pool.clone());
let user_repo = Arc::new(UserRepository::new(pool));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
let token_code = "valid-token";
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
client.post("/api/signup")
.header(ContentType::JSON)
.body(json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": token_code
}).to_string())
.dispatch()
.await;
// Login to get JWT
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser",
"password": "password123"
}).to_string())
.dispatch()
.await;
let auth_resp: serde_json::Value = serde_json::from_str(&login_res.into_string().await.unwrap()).unwrap();
let jwt = auth_resp["token"].as_str().unwrap();
// 1. Delete with WRONG password
let response = client.delete("/api/settings")
.header(ContentType::JSON)
.header(rocket::http::Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"password": "wrongpassword",
"totp_code": null
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Unauthorized);
// 2. Delete with CORRECT password
let response = client.delete("/api/settings")
.header(ContentType::JSON)
.header(rocket::http::Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"password": "password123",
"totp_code": null
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
// Verify user is gone
assert!(users.users.lock().unwrap().is_empty());
}
+142
View File
@@ -0,0 +1,142 @@
use backend::rocket_builder;
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
use backend::repo::message_repo::MessageRepository;
use backend::svc::chat_svc::ChatService;
use backend::repo::{Repo, AccessTokenRepoTrait};
use rocket::local::asynchronous::Client;
use rocket::http::{Status, ContentType, Header};
use serde_json::{json, Value};
use std::sync::{Arc, Mutex};
use sqlx::PgPool;
use chrono::Utc;
use backend::svc::llm_service::LlmService;
async fn setup_client_with_svc(chat_service: ChatService, users: Arc<MockUserRepo>, tokens: Arc<MockTokenRepo>) -> (Client, String) {
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
// Create a user and get JWT
let token_code = "valid-token";
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
let jwt = {
let signup_res = client.post("/api/signup")
.header(ContentType::JSON)
.body(json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": token_code
}).to_string())
.dispatch()
.await;
assert_eq!(signup_res.status(), Status::Ok);
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser",
"password": "password123"
}).to_string())
.dispatch()
.await;
assert_eq!(login_res.status(), Status::Ok, "Login failed");
let body = login_res.into_string().await.expect("login body");
let auth_resp: serde_json::Value = serde_json::from_str(&body).unwrap();
auth_resp["token"].as_str().unwrap().to_string()
};
(client, jwt)
}
#[rocket::async_test]
async fn test_chat_event_stream_consistency() {
unsafe { std::env::set_var("JWT_SECRET", "test_secret"); }
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = <MessageRepository as Repo>::new(pool.clone());
let users_repo = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let llm_service = LlmService::new();
let chat_service = ChatService::new(1024, messages, users_repo.clone(), llm_service);
let (client, jwt) = setup_client_with_svc(chat_service.clone(), users_repo.clone(), tokens_repo.clone()).await;
// Use the same client for sender but with a different user (or the same, doesn't matter for broadcast)
// Actually, to simulate another user, we should sign up another user.
let jwt_sender = {
let token_code = "valid-token-2";
tokens_repo.create_new(1, "test2", token_code, 1, Utc::now(), Utc::now() + chrono::Duration::days(1)).await.unwrap();
let signup_res = client.post("/api/signup")
.header(ContentType::JSON)
.body(json!({
"email": "test2@example.com",
"username": "testuser2",
"password": "password123",
"access_token": token_code
}).to_string())
.dispatch()
.await;
assert_eq!(signup_res.status(), Status::Ok);
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser2",
"password": "password123"
}).to_string())
.dispatch()
.await;
let body = login_res.into_string().await.unwrap();
let auth_resp: serde_json::Value = serde_json::from_str(&body).unwrap();
auth_resp["token"].as_str().unwrap().to_string()
};
let channel_id = 1;
// Start listening to the event stream
let mut response = client.get(format!("/api/events/{}", channel_id))
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let num_messages = 5; // Reduced for faster debugging
let mut received_count = 0;
let jwt_clone = jwt.clone();
tokio::spawn(async move {
for i in 0..num_messages {
let msg = format!("Message {}", i);
let res = sender_client.post(format!("/api/chat/{}", channel_id))
.header(ContentType::JSON)
.header(Header::new("Authorization", format!("Bearer {}", jwt_clone)))
.body(json!({
"display_name": "testuser",
"user_id": 1,
"text": msg,
"timestamp": Utc::now()
}).to_string())
.dispatch()
.await;
assert_eq!(res.status(), Status::Ok);
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
});
// Wait a bit for messages to be posted
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Consume the stream
let text = response.into_string().await.unwrap();
println!("Received chunk: {}", text);
let mut received_count = 0;
for line in text.lines() {
if line.starts_with("data:") {
received_count += 1;
}
}
assert_eq!(received_count, num_messages, "Should receive all posted messages. Received: {}. Full text: {}", received_count, text);
}
+121
View File
@@ -0,0 +1,121 @@
use backend::rocket_builder;
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
use backend::repo::message_repo::MessageRepository;
use backend::svc::chat_svc::ChatService;
use backend::repo::user_repo::UserRepository;
use backend::repo::{Repo, AccessTokenRepoTrait};
use rocket::local::asynchronous::Client;
use rocket::http::{Status, ContentType, Header};
use serde_json::json;
use std::sync::{Arc, Mutex};
use sqlx::PgPool;
use chrono::Utc;
use backend::svc::llm_service::LlmService;
async fn setup_client() -> (Client, Arc<MockUserRepo>, String) {
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = MessageRepository::new(pool.clone());
let user_repo = Arc::new(UserRepository::new(pool));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
// Create a user and get JWT
let token_code = "valid-token";
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
client.post("/api/signup")
.header(ContentType::JSON)
.body(json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": token_code
}).to_string())
.dispatch()
.await;
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser",
"password": "password123"
}).to_string())
.dispatch()
.await;
let auth_resp: serde_json::Value = serde_json::from_str(&login_res.into_string().await.unwrap()).unwrap();
let jwt = auth_resp["token"].as_str().unwrap().to_string();
(client, users, jwt)
}
#[rocket::async_test]
async fn test_change_display_name() {
let (client, users, jwt) = setup_client().await;
let response = client.patch("/api/settings/display_name")
.header(ContentType::JSON)
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"display_name": "New Display Name"
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let user = users.users.lock().unwrap()[0].clone();
assert_eq!(user.nickname, Some("New Display Name".to_string()));
}
#[rocket::async_test]
async fn test_change_username() {
let (client, users, jwt) = setup_client().await;
let response = client.patch("/api/settings/username")
.header(ContentType::JSON)
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"username": "newusername"
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let user = users.users.lock().unwrap()[0].clone();
assert_eq!(user.username, "newusername");
}
#[rocket::async_test]
async fn test_change_password() {
let (client, _, jwt) = setup_client().await;
let response = client.post("/api/settings/password")
.header(ContentType::JSON)
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"old_password": "password123",
"new_password": "newpassword456"
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
// Verify login with new password
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser",
"password": "newpassword456"
}).to_string())
.dispatch()
.await;
assert_eq!(login_res.status(), Status::Ok);
}