3 Commits

Author SHA1 Message Date
zxq5 d33eee1281 frontend v0.4.0 2026-04-06 01:02:39 +01:00
zxq5 7c9b733813 updated gitignore 2026-04-06 01:00:27 +01:00
zxq5 bda1ef251a full backend rewrite.
calling this v0.4.0
2026-04-06 00:57:23 +01:00
92 changed files with 3769 additions and 1649 deletions
+3
View File
@@ -1,6 +1,9 @@
*/target
.env
.log*
Cargo.lock
.cargo/
.sqlx/
docker-compose*
+4 -1
View File
@@ -2,7 +2,10 @@
<module type="JAVA_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$" />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/backend/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/backend/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
+17
View File
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="chatapp dev" uuid="81992477-fd6f-427e-a27e-7378c26db6ef">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>
<jdbc-url>jdbc:postgresql://100.118.108.58:5432/chatapp_dev</jdbc-url>
<jdbc-additional-properties>
<property name="com.intellij.clouds.kubernetes.db.host.port" />
<property name="com.intellij.clouds.kubernetes.db.enabled" value="false" />
<property name="com.intellij.clouds.kubernetes.db.container.port" />
</jdbc-additional-properties>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>
+11
View File
@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>
<option name="externalProjectPath" value="$PROJECT_DIR$/android" />
</GradleProjectSettings>
</option>
</component>
</project>
+1
View File
@@ -1,4 +1,5 @@
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="ProjectRootManager" version="2" project-jdk-name="25" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
+3
View File
@@ -1,7 +1,9 @@
*.iml
.gradle
/local.properties
/keystore.properties
/.idea/caches
/.idea/.cache
/.idea/libraries
/.idea/modules.xml
/.idea/workspace.xml
@@ -13,3 +15,4 @@
.externalNativeBuild
.cxx
local.properties
release/
+8
View File
@@ -4,6 +4,14 @@
<selectionStates>
<SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-04-02T14:33:39.814557661Z">
<Target type="DEFAULT_BOOT">
<handle>
<DeviceId pluginId="PhysicalDevice" identifier="serial=00319362N000094" />
</handle>
</Target>
</DropdownSelection>
<DialogSelection />
</SelectionState>
<SelectionState runConfigName="MainActivity">
<option name="selectionMode" value="DROPDOWN" />
+33 -2
View File
@@ -1,3 +1,5 @@
import java.util.Properties
plugins {
alias(libs.plugins.android.application)
alias(libs.plugins.kotlin.compose)
@@ -8,6 +10,25 @@ android {
namespace = "dev.zxq5.chatapp.android"
compileSdk = 35
val keystorePropertiesFile = rootProject.file("local.properties")
val keystoreProperties = Properties()
if (keystorePropertiesFile.exists()) {
keystoreProperties.load(keystorePropertiesFile.inputStream())
}
signingConfigs {
create("release") {
storeFile = file("${System.getProperty("user.home")}/keystores/chatapp.jks")
storePassword = keystoreProperties["KEYSTORE_PASSWORD"] as String?
?: System.getenv("KEYSTORE_PASSWORD")
?: ""
keyAlias = "chatapp"
keyPassword = keystoreProperties["KEY_PASSWORD"] as String?
?: System.getenv("KEY_PASSWORD")
?: ""
}
}
defaultConfig {
applicationId = "dev.zxq5.chatapp.android"
minSdk = 26
@@ -20,19 +41,30 @@ android {
buildTypes {
release {
isMinifyEnabled = false
isMinifyEnabled = true // shrinks code
isShrinkResources = true // removes unused resources
signingConfig = signingConfigs.getByName("release")
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
buildConfigField("String", "BASE_URL", "\"https://chat.zxq5.dev\"")
}
debug {
isMinifyEnabled = false
isDebuggable = true
applicationIdSuffix = ".debug" // lets you install both side by side
buildConfigField("String", "BASE_URL", "\"http://zxq5-x1:8000\"")
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
}
buildFeatures {
compose = true
buildConfig = true
}
}
@@ -44,7 +76,6 @@ dependencies {
implementation(libs.ktor.client.auth) // Auth plugin
// Kotlinx Serialization
implementation(libs.kotlinx.serialization.json)
// Coroutines
implementation(libs.kotlinx.coroutines.android)
+26
View File
@@ -19,3 +19,29 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
# Ktor
-keep class io.ktor.** { *; }
-keep class kotlinx.coroutines.** { *; }
# Kotlinx serialization
-keepattributes *Annotation*, InnerClasses
-dontnote kotlinx.serialization.AnnotationsKt
-keep,includedescriptorclasses class dev.zxq5.chatapp.android.**$$serializer { *; }
-keepclassmembers class dev.zxq5.chatapp.android.** {
*** Companion;
}
-keepclasseswithmembers class dev.zxq5.chatapp.android.** {
kotlinx.serialization.KSerializer serializer(...);
}
# Keep model classes (serialization needs these)
-keep class dev.zxq5.chatapp.android.api.model.** { *; }
-keep class dev.zxq5.chatapp.android.data.model.** { *; }
# Fix for missing errorprone and javax annotations used by Tink and other libraries
-dontwarn com.google.errorprone.annotations.**
-dontwarn javax.annotation.**
# Fix for missing java.lang.management referenced by Ktor (not available on Android)
-dontwarn java.lang.management.**
+10
View File
@@ -3,6 +3,10 @@
xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE"/>
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
<application
android:name=".ChatApplication"
@@ -15,6 +19,12 @@
android:supportsRtl="true"
android:theme="@style/Theme.Chatapp"
android:usesCleartextTraffic="true">
<service
android:name=".core.service.MessageStreamService"
android:foregroundServiceType="dataSync"
android:exported="false"/>
<activity
android:name=".MainActivity"
android:exported="true"
@@ -1,6 +1,9 @@
package dev.zxq5.chatapp.android
import android.app.Application
import android.app.NotificationChannel
import android.app.NotificationManager
import android.os.Build
import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.data.repository.ChatRepository
@@ -8,6 +11,10 @@ import dev.zxq5.chatapp.android.data.repository.SettingsRepository
class ChatApplication : Application() {
object AppState {
var isInForeground = false
}
val tokenStore by lazy { TokenStore(this) }
val authRepository by lazy { AuthRepository(tokenStore) }
val chatRepository by lazy { ChatRepository(tokenStore) }
@@ -15,5 +22,30 @@ class ChatApplication : Application() {
override fun onCreate() {
super.onCreate()
createNotificationChannels()
}
private fun createNotificationChannels() {
val messageChannel = NotificationChannel(
"messages",
"Messages",
NotificationManager.IMPORTANCE_HIGH
).apply {
description = "New message notifications"
enableVibration(true)
}
// add this — required for the foreground service persistent notification
val serviceChannel = NotificationChannel(
"service",
"Background connection",
NotificationManager.IMPORTANCE_LOW
).apply {
description = "Keeps messages running in background"
}
val manager = getSystemService(NotificationManager::class.java)
manager.createNotificationChannel(messageChannel)
manager.createNotificationChannel(serviceChannel)
}
}
@@ -4,25 +4,42 @@ import android.os.Bundle
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.ChatBubbleOutline
import androidx.compose.material.icons.outlined.PeopleOutline
import androidx.compose.material.icons.outlined.Settings
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.NavigationBar
import androidx.compose.material3.NavigationBarItem
import androidx.compose.material3.NavigationBarItemDefaults
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel
import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.ChatApplication.AppState
import dev.zxq5.chatapp.android.core.service.MessageStreamService
import dev.zxq5.chatapp.android.data.repository.AuthState
import dev.zxq5.chatapp.android.data.repository.ChatRepository
import dev.zxq5.chatapp.android.data.repository.SettingsRepository
import dev.zxq5.chatapp.android.feature.auth.AuthScreen
import dev.zxq5.chatapp.android.feature.auth.AuthViewModel
import dev.zxq5.chatapp.android.feature.chat.ChatScreen
import dev.zxq5.chatapp.android.feature.chat.ChatViewModel
import dev.zxq5.chatapp.android.feature.chat.Screen
import dev.zxq5.chatapp.android.feature.settings.SettingsViewModel
import dev.zxq5.chatapp.android.feature.auth.AuthScreen
import dev.zxq5.chatapp.android.feature.chat.ChatScreen
import dev.zxq5.chatapp.android.feature.contacts.ContactsScreen
import dev.zxq5.chatapp.android.feature.settings.SettingsScreen
import dev.zxq5.chatapp.android.feature.settings.SettingsViewModel
import dev.zxq5.chatapp.android.ui.theme.ChatappTheme
class MainActivity : ComponentActivity() {
@@ -30,7 +47,6 @@ class MainActivity : ComponentActivity() {
super.onCreate(savedInstanceState)
val app = application as ChatApplication
val tokenStore = app.tokenStore
val authRepository = app.authRepository
val chatRepository = app.chatRepository
val settingsRepository = app.settingsRepository
@@ -44,11 +60,36 @@ class MainActivity : ComponentActivity() {
val authState by authViewModel.authState.collectAsState()
val currentScreen by chatViewModel.currentScreen.collectAsState()
val selectedChannelId by chatViewModel.channelId.collectAsState()
Scaffold(modifier = Modifier.fillMaxSize()) { innerPadding ->
androidx.compose.foundation.layout.Box(modifier = Modifier.padding(innerPadding)) {
LaunchedEffect(authState) {
when (authState) {
AuthState.Authenticated -> {
AuthState.Authenticated -> MessageStreamService.start(this@MainActivity)
AuthState.Unauthenticated -> MessageStreamService.stop(this@MainActivity)
AuthState.AwaitingTotp -> {}
}
}
LaunchedEffect(Unit) {
intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let {
chatViewModel.switchChannel(it.toLong())
}
}
if (authState == AuthState.Authenticated) {
Scaffold(
modifier = Modifier.fillMaxSize(),
bottomBar = {
// Only show bottom bar if we are NOT inside a specific chat channel
if (selectedChannelId == null) {
BottomDock(
currentScreen = currentScreen,
onNavigate = { chatViewModel.navigateTo(it) }
)
}
}
) { innerPadding ->
Box(modifier = Modifier.padding(innerPadding)) {
when (currentScreen) {
Screen.CHAT -> ChatScreen(
viewModel = chatViewModel,
@@ -58,9 +99,9 @@ class MainActivity : ComponentActivity() {
chatViewModel.clearChat()
}
)
Screen.CONTACTS -> ContactsScreen()
Screen.SETTINGS -> SettingsScreen(
viewModel = settingsViewModel,
onBack = { chatViewModel.navigateTo(Screen.CHAT) },
onLogout = {
authViewModel.logout()
chatViewModel.clearChat()
@@ -68,13 +109,77 @@ class MainActivity : ComponentActivity() {
)
}
}
AuthState.AwaitingTotp, AuthState.Unauthenticated -> {
}
} else {
AuthScreen(viewModel = authViewModel)
}
}
}
}
override fun onResume() {
super.onResume()
AppState.isInForeground = true
}
override fun onPause() {
super.onPause()
AppState.isInForeground = false
}
override fun onNewIntent(intent: android.content.Intent) {
super.onNewIntent(intent)
intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let { channelId ->
MessageStreamService.instance?.activeChannelId = channelId.toLong()
}
}
}
@Composable
fun BottomDock(currentScreen: Screen, onNavigate: (Screen) -> Unit) {
NavigationBar(
containerColor = MaterialTheme.colorScheme.background,
tonalElevation = 0.dp,
modifier = Modifier
.height(80.dp)
.border(
0.5.dp,
MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f),
RoundedCornerShape(topStart = 0.dp, topEnd = 0.dp)
)
) {
NavigationBarItem(
selected = currentScreen == Screen.CHAT,
onClick = { onNavigate(Screen.CHAT) },
icon = { Icon(Icons.Outlined.ChatBubbleOutline, contentDescription = "Chat") },
label = { Text("chat", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.CONTACTS,
onClick = { onNavigate(Screen.CONTACTS) },
icon = { Icon(Icons.Outlined.PeopleOutline, contentDescription = "Contacts") },
label = { Text("contacts", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.SETTINGS,
onClick = { onNavigate(Screen.SETTINGS) },
icon = { Icon(Icons.Outlined.Settings, contentDescription = "Settings") },
label = { Text("settings", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
}
}
@@ -1,10 +1,11 @@
package dev.zxq5.chatapp.android.api
import android.util.Log
import dev.zxq5.chatapp.android.BuildConfig
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.LoginRequest
import dev.zxq5.chatapp.android.api.model.LoginResponse
import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode
import dev.zxq5.chatapp.android.core.BASE_URL
import dev.zxq5.chatapp.android.core.error.ApiResult
import dev.zxq5.chatapp.android.api.model.SignupRequest
import io.ktor.client.HttpClient
@@ -1,14 +1,17 @@
package dev.zxq5.chatapp.android.api
import dev.zxq5.chatapp.android.core.BASE_URL
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SendMessage
import dev.zxq5.chatapp.android.api.model.SpaceDto
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.android.Android
import io.ktor.client.plugins.auth.Auth
import io.ktor.client.plugins.auth.providers.BearerTokens
import io.ktor.client.plugins.auth.providers.bearer
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.get
import io.ktor.client.request.post
import io.ktor.client.request.prepareGet
import io.ktor.client.request.setBody
@@ -16,12 +19,15 @@ import io.ktor.client.statement.bodyAsChannel
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import io.ktor.utils.io.readUTF8Line
import io.ktor.utils.io.readLine
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.serialization.json.Json
import kotlin.time.Clock
import kotlin.time.ExperimentalTime
class ChatClient(private val token: String) {
private val http = HttpClient(Android) {
install(ContentNegotiation) {
json(Json { ignoreUnknownKeys = true })
@@ -33,18 +39,21 @@ class ChatClient(private val token: String) {
}
}
suspend fun sendMessage(channelId: Int, userId: Int, text: String) {
suspend fun getAccessibleChannels(): List<SpaceDto> = http.get("${BASE_URL}/api/accessible_channels").body()
@OptIn(ExperimentalTime::class)
suspend fun sendMessage(channelId: Long, userId: Int, text: String) {
http.post("${BASE_URL}/api/chat/$channelId") {
contentType(ContentType.Application.Json)
setBody(SendMessage(user_id = userId, text = text, timestamp = System.currentTimeMillis()))
setBody(SendMessage(user_id = userId, text = text, timestamp = Clock.System.now()))
}
}
fun messageStream(channelId: Int): Flow<Message> = flow {
fun messageStream(channelId: Long): Flow<Message> = flow {
http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response ->
val channel = response.bodyAsChannel()
while (!channel.isClosedForRead) {
val line = channel.readUTF8Line(256) ?: break
val line = channel.readLine() ?: break
if (line.startsWith("data:")) {
val json = line.removePrefix("data:").trim()
runCatching { Json.decodeFromString<Message>(json) }
@@ -54,4 +63,3 @@ class ChatClient(private val token: String) {
}
}
}
@@ -1,6 +1,7 @@
package dev.zxq5.chatapp.android.api
import android.util.Log
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest
import dev.zxq5.chatapp.android.api.model.DisplayNameRequest
import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest
@@ -10,7 +11,6 @@ import dev.zxq5.chatapp.android.api.model.TotpStatus
import dev.zxq5.chatapp.android.api.model.UsernameRequest
import dev.zxq5.chatapp.android.api.model.TotpDeleteRequest
import dev.zxq5.chatapp.android.api.model.PasswordRequest
import dev.zxq5.chatapp.android.core.BASE_URL
import dev.zxq5.chatapp.android.core.error.ApiResult
import io.ktor.client.HttpClient
import io.ktor.client.call.body
@@ -0,0 +1,15 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Channel @OptIn(ExperimentalTime::class) constructor(
val id: Long,
val name: String,
val description: String? = null,
val space_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@@ -1,11 +1,13 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Message(
data class Message @OptIn(ExperimentalTime::class) constructor(
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Long
val timestamp: Instant
)
@@ -1,10 +1,12 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class SendMessage(
data class SendMessage @OptIn(ExperimentalTime::class) constructor(
val user_id: Int,
val text: String,
val timestamp: Long
val timestamp: Instant
)
@@ -0,0 +1,28 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Space @OptIn(ExperimentalTime::class) constructor(
val id: Long,
val name: String,
val description: String? = null,
val owner_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@Serializable
data class SpaceDto @OptIn(ExperimentalTime::class) constructor(
val channels: List<Channel>,
val id: Long,
val name: String,
val description: String? = null,
val owner_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@@ -1,3 +1,3 @@
package dev.zxq5.chatapp.android.core
const val BASE_URL = "http://zxq5-x1:8000"
//const val BASE_URL = "http://zxq5-x1:8000"
@@ -0,0 +1,104 @@
package dev.zxq5.chatapp.android.core.service
import android.app.Service
import android.content.Context
import android.content.Intent
import android.os.Build
import android.os.IBinder
import android.util.Log
import dev.zxq5.chatapp.android.ChatApplication
import dev.zxq5.chatapp.android.data.repository.ChatRepository
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
// core/service/MessageStreamService.kt
class MessageStreamService : Service() {
private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private lateinit var notificationService: NotificationService
private lateinit var chatRepository: ChatRepository
// which channel the user is currently looking at
// set by the ViewModel when the user opens/closes a channel
var activeChannelId: Long? = null
set(value) {
field = value
Log.d("Service", "activeChannelId set to $value")
if (value != null) {
// restart stream with new channel
currentStreamJob?.cancel()
observeMessages()
}
}
private var currentStreamJob: kotlinx.coroutines.Job? = null
companion object {
var instance: MessageStreamService? = null
fun start(context: Context) {
val intent = Intent(context, MessageStreamService::class.java)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
context.startForegroundService(intent)
} else {
context.startService(intent)
}
}
fun stop(context: Context) {
context.stopService(Intent(context, MessageStreamService::class.java))
}
}
override fun onCreate() {
super.onCreate()
instance = this
notificationService = NotificationService(this)
chatRepository = (application as ChatApplication).chatRepository
}
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
startForeground(
NotificationService.FOREGROUND_NOTIFICATION_ID,
notificationService.buildForegroundNotification()
)
observeMessages()
return START_STICKY // restart if killed
}
private fun observeMessages() {
val channelId = activeChannelId ?: chatRepository.getLastActiveChannel()
Log.d("Service", "observeMessages called, channelId=$channelId")
if (channelId == null) {
Log.d("Service", "No channel to observe, waiting for switchChannel")
return
}
Log.d("Service", "Starting stream for channel $channelId")
currentStreamJob = serviceScope.launch {
chatRepository.messageStream(channelId)
.catch { e -> Log.e("Service", "Stream error", e) }
.collect { message ->
if (!ChatApplication.AppState.isInForeground) { // no channel focused, always notify
notificationService.showMessageNotification(
conversationId = activeChannelId.toString(),
senderName = message.display_name,
messagePreview = message.text.take(80)
)
}
}
}
}
override fun onBind(intent: Intent?): IBinder? = null
override fun onDestroy() {
super.onDestroy()
instance = null
serviceScope.cancel()
}
}
@@ -0,0 +1,94 @@
package dev.zxq5.chatapp.android.core.service
import android.app.Notification
import android.app.NotificationChannel
import android.app.NotificationManager
import android.app.PendingIntent
import android.content.Context
import android.content.Intent
import android.os.Build
import androidx.core.app.NotificationCompat
import dev.zxq5.chatapp.android.MainActivity
import dev.zxq5.chatapp.android.R
class NotificationService(private val context: Context) {
companion object {
const val CHANNEL_ID = "messages"
const val FOREGROUND_NOTIFICATION_ID = 1 // ← this needs to exist
}
private val manager = context.getSystemService(NotificationManager::class.java)
fun createChannels() {
// channel for new message notifications
val messageChannel = NotificationChannel(
CHANNEL_ID,
"Messages",
NotificationManager.IMPORTANCE_HIGH
).apply {
enableVibration(true)
}
// channel for the persistent foreground service notification
// low importance so it doesn't make noise
val serviceChannel = NotificationChannel(
"service",
"Background connection",
NotificationManager.IMPORTANCE_LOW
)
val mgr = context.getSystemService(NotificationManager::class.java)
mgr.createNotificationChannel(messageChannel)
mgr.createNotificationChannel(serviceChannel)
}
fun buildForegroundNotification(): Notification {
return NotificationCompat.Builder(context, "service")
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle("chatapp")
.setContentText("Connected")
.setOngoing(true)
.setSilent(true)
.build()
}
fun showMessageNotification(
conversationId: String,
senderName: String,
messagePreview: String, // for E2E this would be "New message" — no plaintext
notificationId: Int = conversationId.hashCode()
) {
// intent that opens the app to the right conversation when tapped
val intent = Intent(context, MainActivity::class.java).apply {
flags = Intent.FLAG_ACTIVITY_SINGLE_TOP
putExtra("conversation_id", conversationId)
}
val pendingIntent = PendingIntent.getActivity(
context,
notificationId,
intent,
PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE
)
val notification = NotificationCompat.Builder(context, "messages")
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle(senderName)
.setContentText(messagePreview)
.setPriority(NotificationCompat.PRIORITY_HIGH)
.setContentIntent(pendingIntent)
.setAutoCancel(true) // dismiss on tap
.build()
manager.notify(notificationId, notification)
}
fun dismissNotification(conversationId: String) {
manager.cancel(conversationId.hashCode())
}
fun dismissAll() {
manager.cancelAll()
}
}
@@ -3,6 +3,7 @@ package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.ChatClient
import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SpaceDto
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emptyFlow
@@ -11,6 +12,8 @@ class ChatRepository(private val tokenStore: TokenStore) {
private var _chatClient: ChatClient? = null
private var _lastToken: String? = null
private var _lastActiveChannel: Long? = null
private fun getChatClient(): ChatClient? {
val token = tokenStore.get() ?: return null
if (_chatClient == null || token != _lastToken) {
@@ -25,14 +28,23 @@ class ChatRepository(private val tokenStore: TokenStore) {
_lastToken = null
}
fun getLastActiveChannel(): Long? {
return _lastActiveChannel
}
fun getUserId() = tokenStore.getUserId()
suspend fun sendMessage(channelId: Int, text: String) {
suspend fun getAccessibleChannels(): List<SpaceDto> {
return getChatClient()?.getAccessibleChannels() ?: emptyList()
}
suspend fun sendMessage(channelId: Long, text: String) {
val userId = tokenStore.getUserId() ?: return
getChatClient()?.sendMessage(channelId, userId, text)
}
fun messageStream(channelId: Int): Flow<Message> {
fun messageStream(channelId: Long): Flow<Message> {
_lastActiveChannel = channelId
return getChatClient()?.messageStream(channelId) ?: emptyFlow()
}
}
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.feature.auth
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.core.service.MessageStreamService
import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.data.repository.LoginResult
import dev.zxq5.chatapp.android.data.repository.SignupResult
@@ -3,8 +3,12 @@ package dev.zxq5.chatapp.android.feature.chat
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.Channel
import dev.zxq5.chatapp.android.data.repository.ChatRepository
import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.Space
import dev.zxq5.chatapp.android.api.model.SpaceDto
import dev.zxq5.chatapp.android.core.service.MessageStreamService
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
@@ -12,15 +16,13 @@ import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
private val _messages = MutableStateFlow<List<Message>>(emptyList())
val messages: StateFlow<List<Message>> = _messages
private val _channelId = MutableStateFlow<Int?>(null)
val channelId: StateFlow<Int?> = _channelId
private val _channelId = MutableStateFlow<Long?>(null)
val channelId: StateFlow<Long?> = _channelId
private val _currentScreen = MutableStateFlow(Screen.CHAT)
val currentScreen: StateFlow<Screen> = _currentScreen
@@ -28,11 +30,35 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
private val _currentUserId = MutableStateFlow<Int?>(null)
val currentUserId: StateFlow<Int?> = _currentUserId
private val _spaces = MutableStateFlow<List<SpaceDto>>(emptyList())
val spaces: StateFlow<List<SpaceDto>> = _spaces
private val _error = MutableStateFlow<String?>(null)
val error: StateFlow<String?> = _error
private val _channelError = MutableStateFlow<String?>(null)
val channelError: StateFlow<String?> = _channelError
private var streamJob: Job? = null
init {
_currentUserId.value = chatRepository.getUserId()
observeChannel()
loadAccessibleChannels()
}
fun loadAccessibleChannels() {
_error.value = null
viewModelScope.launch {
runCatching {
chatRepository.getAccessibleChannels()
}.onSuccess { data ->
_spaces.value = data
}.onFailure { e ->
Log.e("Chat", "Failed to load spaces", e)
_error.value = "Failed to load channels: ${e.message}"
}
}
}
private fun observeChannel() {
@@ -40,11 +66,13 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_channelId.collect { id ->
streamJob?.cancel()
_messages.value = emptyList()
_channelError.value = null
if (id != null) {
streamJob = launch {
chatRepository.messageStream(id)
.catch { e ->
Log.e("Chat", "Stream error", e)
_channelError.value = "Connection lost: ${e.message}"
}
.collect { message ->
_messages.update { it + message }
@@ -59,12 +87,14 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_currentScreen.value = screen
}
fun switchChannel(id: Int?) {
fun switchChannel(id: Long?) {
_channelId.value = id
MessageStreamService.instance?.activeChannelId = id
if (id != null) {
// Refresh user ID just in case it wasn't available at init
_currentUserId.value = chatRepository.getUserId()
chatRepository.resetClient()
}
}
@@ -78,6 +108,7 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
)
}.onFailure { e ->
Log.e("Chat", "Send message error", e)
_channelError.value = "Failed to send message"
}
}
}
@@ -86,8 +117,14 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_messages.value = emptyList()
_channelId.value = null
_currentUserId.value = null
_error.value = null
_channelError.value = null
streamJob?.cancel()
chatRepository.resetClient()
}
MessageStreamService.instance?.activeChannelId = null
}
fun clearChannelError() {
_channelError.value = null
}
}
@@ -1,5 +1,5 @@
package dev.zxq5.chatapp.android.feature.chat
enum class Screen {
CHAT, SETTINGS
CHAT, CONTACTS, SETTINGS
}
@@ -28,22 +28,21 @@ import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.filled.ArrowBack
import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.filled.Refresh
import androidx.compose.material.icons.filled.Send
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.NavigationBar
import androidx.compose.material3.NavigationBarItem
import androidx.compose.material3.NavigationBarItemDefaults
import androidx.compose.material3.Scaffold
import androidx.compose.material3.SnackbarHost
import androidx.compose.material3.SnackbarHostState
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material.icons.filled.Send
import androidx.compose.material.icons.outlined.ChatBubbleOutline
import androidx.compose.material.icons.outlined.Settings
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
@@ -56,26 +55,29 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.input.ImeAction
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.api.model.Channel
import dev.zxq5.chatapp.android.api.model.Message
import java.text.DateFormat
import java.util.Date
import kotlin.time.ExperimentalTime
@Composable
fun ChatScreen(
viewModel: ChatViewModel,
onNavigateToSettings: () -> Unit,
onLogout: () -> Unit // Note: logout is now part of SettingsScreen in this UI, but we'll keep the param for now
onLogout: () -> Unit
) {
val selectedChannelId by viewModel.channelId.collectAsState()
if (selectedChannelId == null) {
ChannelListScreen(
viewModel = viewModel,
onChannelSelect = { viewModel.switchChannel(it) },
onNavigateToSettings = onNavigateToSettings
onChannelSelect = { viewModel.switchChannel(it) }
)
} else {
MessageScreen(
@@ -90,20 +92,15 @@ fun ChatScreen(
@Composable
fun ChannelListScreen(
viewModel: ChatViewModel,
onChannelSelect: (Int) -> Unit,
onNavigateToSettings: () -> Unit
onChannelSelect: (Long) -> Unit
) {
val spaces by viewModel.spaces.collectAsState()
val error by viewModel.error.collectAsState()
Scaffold(
containerColor = MaterialTheme.colorScheme.background,
topBar = {
Column {
Spacer(Modifier.height(8.dp))
Text(
"contacts",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
modifier = Modifier.padding(horizontal = 20.dp)
)
TopAppBar(
title = {
Text(
@@ -115,103 +112,69 @@ fun ChannelListScreen(
colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent,
titleContentColor = MaterialTheme.colorScheme.onSurface
)
),
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
)
Text(
"5 channels · end-to-end encrypted",
"Public channels - dms coming soon.",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f),
modifier = Modifier.padding(horizontal = 20.dp, vertical = 2.dp)
)
Spacer(Modifier.height(12.dp))
Row(
modifier = Modifier
.fillMaxWidth()
.background(MaterialTheme.colorScheme.secondary.copy(alpha = 0.2f))
.padding(horizontal = 20.dp, vertical = 8.dp),
verticalAlignment = Alignment.CenterVertically
) {
Box(
modifier = Modifier
.size(6.dp)
.clip(CircleShape)
.background(MaterialTheme.colorScheme.primary)
)
Spacer(Modifier.width(10.dp))
Text(
"global · walkie talkie",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.weight(1f)
)
Surface(
color = Color.Transparent,
border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant),
shape = RoundedCornerShape(4.dp)
) {
Text(
"hold to talk",
modifier = Modifier.padding(horizontal = 8.dp, vertical = 3.dp),
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
},
bottomBar = { BottomDock(viewModel, onNavigateToSettings) }
) { padding ->
if (error != null) {
Column(
modifier = Modifier.fillMaxSize().padding(padding),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
Text(
text = error!!,
color = MaterialTheme.colorScheme.error,
textAlign = TextAlign.Center,
modifier = Modifier.padding(16.dp)
)
Button(onClick = { viewModel.loadAccessibleChannels() }) {
Icon(Icons.Default.Refresh, contentDescription = null)
Spacer(Modifier.width(8.dp))
Text("Retry")
}
}
} else {
LazyColumn(modifier = Modifier.padding(padding).fillMaxSize()) {
items(10) { i ->
val id = i + 1
ChannelItem(id = id, onClick = { onChannelSelect(id) })
spaces.forEach { spaceDto ->
item {
Text(
text = spaceDto.name.lowercase(),
modifier = Modifier.padding(horizontal = 20.dp, vertical = 8.dp),
style = MaterialTheme.typography.labelMedium,
color = MaterialTheme.colorScheme.primary,
fontWeight = FontWeight.Bold
)
}
items(spaceDto.channels) { channel ->
ChannelItem(channel = channel, onClick = { onChannelSelect(channel.id) })
HorizontalDivider(
modifier = Modifier.padding(horizontal = 20.dp),
thickness = 0.5.dp,
color = MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f)
)
}
item {
Spacer(Modifier.height(16.dp))
}
}
}
}
}
}
@Composable
fun BottomDock(viewModel: ChatViewModel, onNavigateToSettings: () -> Unit) {
val currentScreen by viewModel.currentScreen.collectAsState()
NavigationBar(
containerColor = MaterialTheme.colorScheme.background,
tonalElevation = 0.dp,
modifier = Modifier.border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f), RoundedCornerShape(topStart = 0.dp, topEnd = 0.dp))
) {
NavigationBarItem(
selected = currentScreen == Screen.CHAT,
onClick = { viewModel.navigateTo(Screen.CHAT) },
icon = { Icon(Icons.Outlined.ChatBubbleOutline, contentDescription = "Chat") },
label = { Text("chat", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.SETTINGS,
onClick = onNavigateToSettings,
icon = { Icon(Icons.Outlined.Settings, contentDescription = "Settings") },
label = { Text("settings", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
}
}
@Composable
fun ChannelItem(id: Int, onClick: () -> Unit) {
fun ChannelItem(channel: Channel, onClick: () -> Unit) {
Row(
modifier = Modifier
.fillMaxWidth()
@@ -227,7 +190,7 @@ fun ChannelItem(id: Int, onClick: () -> Unit) {
contentAlignment = Alignment.Center
) {
Text(
"C$id",
channel.name.take(1).uppercase(),
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
@@ -235,31 +198,30 @@ fun ChannelItem(id: Int, onClick: () -> Unit) {
Spacer(Modifier.width(12.dp))
Column(modifier = Modifier.weight(1f)) {
Text(
text = "channel $id",
text = channel.name,
style = MaterialTheme.typography.bodyLarge,
color = MaterialTheme.colorScheme.onSurface
)
if (channel.description != null) {
Text(
text = "tap to join",
text = channel.description,
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f)
)
}
Text(
"14:22",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f)
)
}
}
}
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit) {
fun MessageScreen(channelId: Long, viewModel: ChatViewModel, onBack: () -> Unit) {
val messages by viewModel.messages.collectAsState()
val currentUserId by viewModel.currentUserId.collectAsState()
val channelError by viewModel.channelError.collectAsState()
var input by remember { mutableStateOf("") }
val listState = rememberLazyListState()
val snackbarHostState = remember { SnackbarHostState() }
LaunchedEffect(messages.size) {
if (messages.isNotEmpty()) {
@@ -267,8 +229,16 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
}
}
LaunchedEffect(channelError) {
channelError?.let {
snackbarHostState.showSnackbar(it)
viewModel.clearChannelError()
}
}
Scaffold(
containerColor = MaterialTheme.colorScheme.background,
snackbarHost = { SnackbarHost(snackbarHostState) },
topBar = {
TopAppBar(
title = {
@@ -294,6 +264,7 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
)
}
},
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent
)
@@ -391,10 +362,13 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
}
}
@OptIn(ExperimentalTime::class)
@Composable
fun MessageBubble(message: Message, currentUserId: Int?) {
val time = remember(message.timestamp) {
DateFormat.getTimeInstance(DateFormat.SHORT).format(Date(message.timestamp)).lowercase()
DateFormat.getTimeInstance(DateFormat.SHORT)
.format(Date(message.timestamp.toEpochMilliseconds()))
.lowercase()
}
val isMe = currentUserId != null && message.user_id == currentUserId
@@ -0,0 +1,52 @@
package dev.zxq5.chatapp.android.feature.contacts
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ContactsScreen() {
Scaffold(
containerColor = MaterialTheme.colorScheme.background,
topBar = {
TopAppBar(
title = {
Text(
"contacts",
style = MaterialTheme.typography.titleLarge,
color = MaterialTheme.colorScheme.onSurface
)
},
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent,
titleContentColor = MaterialTheme.colorScheme.onSurface
)
)
}
) { padding ->
Box(
modifier = Modifier
.fillMaxSize()
.padding(padding),
contentAlignment = Alignment.Center
) {
Text(
text = "Contacts coming soon",
style = MaterialTheme.typography.bodyLarge,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
@@ -63,7 +63,6 @@ import androidx.compose.ui.text.style.TextAlign
@Composable
fun SettingsScreen(
viewModel: SettingsViewModel,
onBack: () -> Unit,
onLogout: () -> Unit
) {
val is2faEnabled by viewModel.is2faEnabled.collectAsState()
@@ -88,15 +87,7 @@ fun SettingsScreen(
color = MaterialTheme.colorScheme.onSurface
)
},
navigationIcon = {
IconButton(onClick = onBack) {
Icon(
Icons.AutoMirrored.Filled.ArrowBack,
contentDescription = "Back",
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
}
},
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
colors = TopAppBarDefaults.topAppBarColors(containerColor = Color.Transparent)
)
}
@@ -0,0 +1,5 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android" android:autoMirrored="true" android:height="24dp" android:tint="#FFFFFF" android:viewportHeight="24" android:viewportWidth="24" android:width="24dp">
<path android:fillColor="@android:color/white" android:pathData="M20,2L4,2c-1.1,0 -1.99,0.9 -1.99,2L2,22l4,-4h14c1.1,0 2,-0.9 2,-2L22,4c0,-1.1 -0.9,-2 -2,-2zM18,14L6,14v-2h12v2zM18,11L6,11L6,9h12v2zM18,8L6,8L6,6h12v2z"/>
</vector>
+10
View File
@@ -0,0 +1,10 @@
# Default ignored files
/shelf/
/workspace.xml
# Ignored default folder with query files
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
+12
View File
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="EMPTY_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
+17
View File
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="chatapp_dev@100.118.108.58" uuid="b14acf5d-6750-469b-8aea-59c8343eb11c">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>
<jdbc-url>jdbc:postgresql://100.118.108.58:5432/chatapp_dev</jdbc-url>
<jdbc-additional-properties>
<property name="com.intellij.clouds.kubernetes.db.host.port" />
<property name="com.intellij.clouds.kubernetes.db.enabled" value="false" />
<property name="com.intellij.clouds.kubernetes.db.container.port" />
</jdbc-additional-properties>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>
+7
View File
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourcePerFileMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
<file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
</component>
</project>
+6
View File
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
</profile>
</component>
+8
View File
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/backend.iml" filepath="$PROJECT_DIR$/.idea/backend.iml" />
</modules>
</component>
</project>
+7
View File
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/sql/schema.sql" dialect="PostgreSQL" />
<file url="PROJECT" dialect="PostgreSQL" />
</component>
</project>
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>
+5 -2
View File
@@ -10,7 +10,7 @@ dotenv = "0.15.0"
futures-util = "0.3.31"
image = "0.25.8"
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
rand = "0.9.2"
rand = "0.8"
redis = { version = "0.25.4", features = ["tokio-comp"] }
reqwest = { version = "0.12.23", features = ["json"] }
rocket = { version = "0.5.1", features = ["json", "secrets"] }
@@ -20,8 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145"
sha2 = "0.10.9"
sqlx = { version = "0.7.4", features = ["macros", "time"] }
sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time"] }
tokio = { version = "1.47.1", features = ["full"] }
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
tracing = "0.1.44"
uuid = { version = "1.18.1", features = ["v4"] }
thiserror = "1.0.69"
utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] }
clap = { version = "4.5", features = ["derive"] }
@@ -1,49 +0,0 @@
-- Add migration script here
CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
password VARCHAR(50) NOT NULL,
display_name VARCHAR(50),
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channels (
id SERIAL PRIMARY KEY,
name VARCHAR(50) NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE messages (
id SERIAL PRIMARY KEY,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
is_edited BOOLEAN DEFAULT FALSE
);
create table attachments (
id SERIAL PRIMARY KEY,
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
path TEXT NOT NULL
);
CREATE INDEX idx_messages_channel_id ON messages (channel_id, id DESC);
CREATE INDEX idx_new_messages ON messages(created_at DESC);
-- Create a function to update the updated_at timestamp
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql';
-- Create trigger for messages table
CREATE TRIGGER update_messages_updated_at
BEFORE UPDATE ON messages
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
@@ -1,20 +0,0 @@
-- Add migration script here
CREATE TABLE sessions (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id),
token TEXT NOT NULL UNIQUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '7 days'
);
CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
RETURNS TRIGGER AS $$
BEGIN
DELETE FROM sessions WHERE expires_at < NOW();
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_cleanup_sessions
AFTER INSERT ON sessions
EXECUTE FUNCTION cleanup_expired_sessions();
@@ -1,5 +0,0 @@
-- Add migration script here
ALTER TABLE users ADD COLUMN email VARCHAR(100);
ALTER TABLE users ADD COLUMN twofa_enabled BOOLEAN DEFAULT FALSE;
ALTER TABLE users ADD COLUMN totp_secret VARCHAR(64);
ALTER TABLE users ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP;
@@ -1,2 +0,0 @@
-- Add migration script here
ALTER TABLE users ALTER COLUMN twofa_enabled SET NOT NULL;
@@ -1,18 +0,0 @@
-- Add migration script here
CREATE TABLE access_codes (
-- identifiers
id SERIAL PRIMARY KEY,
creator_id INTEGER NOT NULL REFERENCES users(id),
-- code data
code VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
-- uses
uses INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 1,
-- time data
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '1 day'
);
@@ -1,10 +0,0 @@
-- Add migration script here
ALTER TABLE access_codes
ALTER COLUMN created_at
TYPE TIMESTAMP WITH TIME ZONE
USING created_at AT TIME ZONE 'UTC';
ALTER TABLE access_codes
ALTER COLUMN expires_at
TYPE TIMESTAMP WITH TIME ZONE
USING expires_at AT TIME ZONE 'UTC';
@@ -1,6 +0,0 @@
-- Add migration script here
TRUNCATE TABLE users CASCADE;
ALTER TABLE users DROP COLUMN password;
ALTER TABLE users ADD COLUMN pass_hash VARCHAR(255) NOT NULL;
ALTER TABLE users ADD CONSTRAINT users_username_unique UNIQUE (username);
@@ -1,13 +0,0 @@
-- Add migration script here
CREATE TYPE status AS ENUM ('pending', 'accepted', 'blocked');
CREATE TABLE relationships (
id SERIAL PRIMARY KEY,
from_user INTEGER NOT NULL REFERENCES users(id),
to_user INTEGER NOT NULL REFERENCES users(id),
status status NOT NULL DEFAULT 'pending',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
);
@@ -0,0 +1,183 @@
-- Add migration script here
-- Add migration script here
CREATE TYPE user_role AS ENUM ('user', 'admin');
CREATE TYPE totp_status AS ENUM ('disabled', 'pending', 'enabled');
CREATE TABLE users (
id BIGSERIAL PRIMARY KEY NOT NULL,
role user_role NOT NULL DEFAULT 'user',
-- profile
nickname VARCHAR(255),
-- basic auth
username VARCHAR(255) UNIQUE NOT NULL,
passhash VARCHAR(255) NOT NULL,
-- email
email VARCHAR(255),
email_verified BOOLEAN DEFAULT FALSE,
-- 2fa
totp_secret VARCHAR(255),
totp_status totp_status NOT NULL DEFAULT 'disabled',
-- update tracking
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE TABLE access_tokens (
id BIGSERIAL PRIMARY KEY NOT NULL,
creator_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
code VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
uses INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '24 hours',
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE refresh_tokens (
id BIGSERIAL PRIMARY KEY NOT NULL,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash VARCHAR(255) NOT NULL,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '7 days'
);
CREATE TABLE spaces (
id BIGSERIAL PRIMARY KEY NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
owner_id BIGINT NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE channels (
id BIGSERIAL PRIMARY KEY NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE space_members (
space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
role user_role DEFAULT 'user',
PRIMARY KEY (space_id, user_id)
);
CREATE TABLE messages (
id BIGSERIAL PRIMARY KEY NOT NULL,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
is_edited BOOLEAN DEFAULT FALSE
);
CREATE TABLE attachments (
id BIGSERIAL PRIMARY KEY NOT NULL,
message_id BIGINT NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
filename VARCHAR(255) NOT NULL,
content_type VARCHAR(100) NOT NULL, -- mime type e.g. image/png, video/mp4
size_bytes BIGINT NOT NULL,
url TEXT NOT NULL, -- path to file on your CDN/storage
width INTEGER, -- null for non-image/video
height INTEGER, -- null for non-image/video
duration_ms INTEGER, -- null for non-audio/video
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TYPE relationship_status AS ENUM ('pending', 'accepted', 'blocked');
CREATE TABLE relationships (
id BIGSERIAL PRIMARY KEY NOT NULL,
from_user BIGINT NOT NULL REFERENCES users(id),
to_user BIGINT NOT NULL REFERENCES users(id),
status relationship_status NOT NULL DEFAULT 'pending',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
);
CREATE INDEX idx_messages_channel_id ON messages(channel_id, created_at DESC);
CREATE INDEX idx_messages_user_id ON messages(user_id);
CREATE INDEX idx_attachments_message ON attachments(message_id);
CREATE INDEX idx_channels_space_id ON channels(space_id);
CREATE INDEX idx_space_members_user ON space_members(user_id);
CREATE INDEX idx_refresh_tokens_hash ON refresh_tokens(token_hash);
CREATE INDEX idx_relationships_from ON relationships(from_user, to_user);
CREATE INDEX idx_relationships_to ON relationships(to_user);
CREATE INDEX idx_access_tokens_code ON access_tokens(code);
CREATE OR REPLACE FUNCTION update_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER users_updated_at
BEFORE UPDATE ON users
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER spaces_updated_at
BEFORE UPDATE ON spaces
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER channels_updated_at
BEFORE UPDATE ON channels
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER messages_updated_at
BEFORE UPDATE ON messages
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER relationships_updated_at
BEFORE UPDATE ON relationships
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER access_tokens_updated_at
BEFORE UPDATE ON access_tokens
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE OR REPLACE FUNCTION add_owner_to_space()
RETURNS TRIGGER AS $$
BEGIN
INSERT INTO space_members (space_id, user_id, role, joined_at)
VALUES (NEW.id, NEW.owner_id, 'admin', NOW());
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER space_owner_becomes_member
AFTER INSERT ON spaces
FOR EACH ROW EXECUTE FUNCTION add_owner_to_space();
@@ -1,21 +1,46 @@
use std::{
sync::LazyLock,
time::{SystemTime, UNIX_EPOCH},
};
use crate::error::ApiResult;
use crate::model::auth::{AccessTokenForm, AuthResponse, LoginCredentials, SignupCredentials};
use crate::svc::auth_svc::AuthService;
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome};
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::{Request, State};
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::svc::access_token_svc::AccessTokenService;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use rand::Rng;
use rocket::{
Request,
http::Status,
request::{self, FromRequest, Outcome},
};
use rocket_db_pools::Connection;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256, digest::block_buffer::Lazy};
use sqlx::postgres::PgQueryResult;
#[post("/signup", data = "<cred>")]
pub async fn signup(
cred: Json<SignupCredentials>,
svc: &State<AuthService>
) -> ApiResult<Json<AuthResponse>> {
let response = svc
.signup(
&cred.email, &cred.username, &cred.password, &cred.access_token,
).await?;
Ok(Json(response))
}
use crate::db::Postgres;
#[post("/login", data = "<cred>")]
pub async fn login(
cred: Json<LoginCredentials>,
svc: &State<AuthService>
) -> ApiResult<Json<AuthResponse>> {
Ok(Json(svc.login(&cred.username, &cred.password).await?))
}
#[post("/invite", data = "<form>")]
pub async fn generate_invite(
session: Session,
form: Json<AccessTokenForm>,
svc: &State<AccessTokenService>
) -> ApiResult<String> {
svc.create(
session.uid, &form.name, form.max_uses,
form.start_date, form.expiry_date).await
}
static JWT_SECRET: LazyLock<String> = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap());
@@ -27,7 +52,7 @@ pub enum TokenScope {
}
pub struct Session {
pub user_id: usize,
pub uid: i64,
}
#[rocket::async_trait]
@@ -37,7 +62,7 @@ impl<'r> FromRequest<'r> for Session {
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match Claims::from_request(req).await {
Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session {
user_id: user.sub as usize,
uid: user.sub as i64,
}),
Outcome::Success(_) => {
eprintln!("warning: user with scope other than Full attempted to access session");
@@ -26,7 +26,7 @@ pub async fn profile_pic(user_id: usize) -> Option<NamedFile> {
Some(image)
} else {
Some(
NamedFile::open("./cdn/profiles/full/default.svg")
NamedFile::open("../../cdn/profiles/full/default.svg")
.await
.ok()?,
)
+70
View File
@@ -0,0 +1,70 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::chat_svc::ChatService;
use chrono::{DateTime, Utc};
use rocket::response::stream::Event;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::{Shutdown, State, ___internal_EventStream as EventStream};
use sqlx::FromRow;
use tokio::select;
use tokio::sync::broadcast;
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message(
msg: Json<ChatMsg>,
chat: &State<ChatService>,
session: Session,
channel_id: i64,
) -> ApiResult<()> {
chat.send(channel_id, session.uid, &msg.text, Utc::now()).await
}
#[get("/events/<channel_id>")]
pub async fn event_stream(
chat: &State<ChatService>,
s: Session,
mut shutdown: Shutdown,
channel_id: i64,
) -> ApiResult<EventStream![]> {
let messages = chat.get_messages(channel_id, 100)
.await?; // if get message returned err, inform user.
let mut rx = chat.subscribe(channel_id).await;
let id = s.uid;
Ok(EventStream! {
for msg in messages {
yield Event::json(&msg);
}
loop {
select!{
_ = &mut shutdown => break, // exit early on shutdown
msg = rx.recv() => match msg {
Ok(msg) => {
tracing::info!("yielding message!");
yield Event::json(&msg)
},
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",);
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => {
tracing::info!("Broadcaster hung up on channel {channel_id}!");
break
},
},
}
}
})
}
+7
View File
@@ -0,0 +1,7 @@
pub mod auth;
pub mod chat;
pub mod totp;
pub mod settings;
pub mod cdn;
pub mod profile;
pub mod space;
+13
View File
@@ -0,0 +1,13 @@
use rocket::State;
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::user_svc::UserService;
#[get("/users/<id>")]
pub async fn display_name(
id: i64,
_ag: Session,
svc: &State<UserService>,
) -> ApiResult<String> {
svc.get_username(id).await
}
+68
View File
@@ -0,0 +1,68 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::settings_svc::SettingsService;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::State;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordForm {
pub old_password: String,
pub new_password: String,
}
#[post("/settings/password", data = "<form>")]
pub async fn change_password(
session: Session,
form: Json<PasswordForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_password(
session.uid, &form.old_password, &form.new_password
).await
}
#[derive(Deserialize, Debug, Clone)]
pub struct DisplayNameForm {
pub display_name: Option<String>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct PasswordAnd2faForm {
pub password: String,
pub totp_code: Option<String>,
}
#[delete("/settings", data = "<data>")]
pub async fn delete_account(
session: Session,
data: Json<PasswordAnd2faForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.delete_account(
session.uid, &data.password, &data.totp_code
).await
}
#[patch("/settings/display_name", data = "<new>")]
pub async fn change_display_name(
session: Session,
new: Json<DisplayNameForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_display_name(session.uid, new.display_name.clone()).await
}
#[derive(Deserialize)]
pub struct UsernameForm {
pub username: String,
}
#[patch("/settings/username", data = "<new>")]
pub async fn change_username(
session: Session,
new: Json<UsernameForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_username(session.uid, &new.username).await
}
+36
View File
@@ -0,0 +1,36 @@
use crate::error::ApiResult;
use crate::model::space::{Space, SpaceDto};
use crate::model::space::Channel;
use crate::repo::{SpaceRepo, ChannelRepo};
use rocket::serde::json::Json;
use rocket::State;
use std::sync::Arc;
use crate::api::auth::Session;
use crate::svc::chat_svc::ChatService;
#[get("/spaces")]
pub async fn list_spaces(
space_repo: &State<Arc<dyn SpaceRepo>>
) -> ApiResult<Json<Vec<Space>>> {
let spaces = space_repo.get_all().await?;
Ok(Json(spaces))
}
#[get("/spaces/<space_id>/channels")]
pub async fn list_channels(
space_id: i64,
channel_repo: &State<Arc<dyn ChannelRepo>>
) -> ApiResult<Json<Vec<Channel>>> {
let channels = channel_repo.get_by_space_id(space_id).await?;
Ok(Json(channels))
}
#[get("/accessible_channels")]
pub async fn get_accessible_channels(
session: Session,
svc: &State<ChatService>
) -> ApiResult<Json<Vec<SpaceDto>>> {
let space = svc.get_accessible_channels(session.uid).await?;
println!("{:?}", space);
Ok(Json(space))
}
+120
View File
@@ -0,0 +1,120 @@
use crate::api::auth::{Claims, Session, TokenScope};
use crate::error::{ApiResult, AppError};
use crate::model::auth::AuthResponse;
use crate::svc::auth_svc::AuthService;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::State;
use totp_rs::{Algorithm, TOTP};
#[derive(Debug, Deserialize)]
pub struct TOTPSixDigitCode {
code: String,
}
#[derive(Debug, sqlx::Type, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[sqlx(type_name = "totp_status", rename_all = "lowercase")]
pub enum TotpStatus {
Enabled,
Pending,
Disabled,
}
#[derive(Serialize)]
pub struct QrResponse {
qr_code: String,
}
#[derive(Deserialize)]
pub struct TotpVerifyRequest {
pub code: String,
}
#[derive(Deserialize)]
pub struct PasswordConfirmation {
password: String,
}
#[derive(Deserialize)]
pub struct PasswordAnd2fa {
pub password: String,
pub totp_code: String,
}
pub fn totp_gen(user_id: i64, secret: &[u8]) -> ApiResult<TOTP> {
TOTP::new(
Algorithm::SHA1,
6,
1,
30,
secret.to_owned(),
Some("chat.zxq5.dev".to_string()),
format!("{}", user_id),
)
.map_err(|_| AppError::internal("failed to generate totp"))
}
#[post("/totp", data = "<form>")]
pub async fn confirm_totp(
user: Session,
form: Json<TOTPSixDigitCode>,
svc: &State<AuthService>,
) -> ApiResult<()> {
svc.confirm_totp(user.uid, &form.code).await
}
#[post("/totp.jpg", data = "<form>")]
pub async fn get_totp(
user: Session,
form: Json<PasswordConfirmation>,
svc: &State<AuthService>,
) -> ApiResult<Json<QrResponse>> {
let secret = svc.get_or_create_totp_secret(user.uid, &form.password).await?;
let qr_b64 = totp_gen(user.uid, secret.as_bytes())
.map_err(|_| AppError::internal("invalid totp secret"))?
.get_qr_base64()
.map_err(|_| AppError::internal("failed to generate qr code"))?;
Ok(Json(QrResponse {
qr_code: format!("data:image/png;base64,{}", qr_b64),
}))
}
#[get("/totp/status")]
pub async fn get_totp_status(
user: Session,
svc: &State<AuthService>,
) -> ApiResult<Json<TotpStatus>> {
Ok(Json(
svc.get_totp_status(user.uid).await?
.then_some(TotpStatus::Enabled)
.unwrap_or(TotpStatus::Disabled),
))
}
#[delete("/totp", data = "<form>")]
pub async fn disable_totp(
user: Session,
form: Json<PasswordAnd2fa>,
svc: &State<AuthService>,
) -> ApiResult<Json<AuthResponse>> {
let response = svc.disable_totp(user.uid, &form.password, &form.totp_code).await?;
Ok(Json(response))
}
#[post("/totp/verify", data = "<body>")]
pub async fn verify_totp(
claims: Claims,
body: Json<TotpVerifyRequest>,
svc: &State<AuthService>,
) -> ApiResult<Json<AuthResponse>> {
// reject if they somehow got here with a full token
if claims.scope != TokenScope::TotpPending {
return Err(AppError::Forbidden);
}
let response = svc.login_totp(claims.sub as i64, &body.code).await?;
Ok(Json(response))
}
-211
View File
@@ -1,211 +0,0 @@
use argon2::{
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
password_hash::{SaltString, rand_core::OsRng},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use rocket::{
http::{CookieJar, Status},
response::{Redirect, status::BadRequest},
serde::json::Json,
time::OffsetDateTime,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{
auth::session::{Claims, Session, TokenScope},
db::Postgres,
user::User,
};
#[derive(Serialize, Deserialize)]
pub struct SignupCredentials {
pub email: String,
pub username: String,
pub password: String,
pub access_token: String,
}
#[derive(Serialize, Deserialize)]
pub struct LoginCredentials {
pub username: String,
pub password: String,
}
#[derive(Serialize, Deserialize)]
pub struct AuthResponse {
pub token: String,
pub totp_required: bool,
}
#[get("/signup")]
pub async fn signup_page() -> Template {
Template::render("signup", context!())
}
#[post("/signup", data = "<cred>")]
pub async fn signup(
cred: Json<SignupCredentials>,
jar: &CookieJar<'_>,
mut db: Connection<Postgres>,
) -> Result<Json<AuthResponse>, Status> {
let token_id = AccessToken::validate(&cred.access_token, &mut db)
.await
.map_err(|_| Status::Unauthorized)?;
let salt = SaltString::generate(&mut OsRng);
let hashed = Argon2::default()
.hash_password(cred.password.as_bytes(), &salt)
.map_err(|_| Status::InternalServerError)?
.to_string();
let result = sqlx::query!(
"INSERT INTO users (email, username, pass_hash) VALUES ($1, $2, $3) RETURNING id",
cred.email,
cred.username,
hashed,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::InternalServerError)?;
let jwt = Claims::new(result.id as usize, TokenScope::Full).encode();
token_id
.use_token(&mut db)
.await
.expect("unable to use access code");
Ok(Json(AuthResponse {
token: jwt,
totp_required: false,
}))
}
#[get("/login")]
pub async fn login_page() -> Template {
Template::render("login", context!())
}
#[post("/login", data = "<cred>")]
pub async fn login(
mut db: Connection<Postgres>,
cred: Json<LoginCredentials>,
) -> Result<Json<AuthResponse>, Status> {
println!("e");
let row = sqlx::query!(
"SELECT id, pass_hash, twofa_enabled FROM users WHERE username = $1",
cred.username,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::Unauthorized)?;
println!("ok");
// verify password as before
let parsed_hash = PasswordHash::new(&row.pass_hash).map_err(|_| Status::InternalServerError)?;
Argon2::default()
.verify_password(cred.password.as_bytes(), &parsed_hash)
.map_err(|_| Status::Unauthorized)?;
println!("ok2");
let user_id = row.id as usize;
// issue either a partial or full token depending on 2FA status
let (session, totp_required) = if row.twofa_enabled {
(Claims::new(user_id, TokenScope::TotpPending), true)
} else {
(Claims::new(user_id, TokenScope::Full), false)
};
Ok(Json(AuthResponse {
token: session.encode(),
totp_required,
}))
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenForm {
pub name: String,
pub max_uses: usize,
pub expiry_date: usize,
pub start_date: usize,
}
#[get("/invite")]
pub async fn invite_page(_s: Session) -> Template {
Template::render("invite", context! {})
}
#[post("/invite", data = "<form>")]
pub async fn generate_invite(
session: Session,
mut db: Connection<Postgres>,
form: Json<AccessTokenForm>,
) -> Result<String, Status> {
if form.start_date > form.expiry_date {
return Err(Status::BadRequest);
}
let code = Uuid::new_v4().to_string();
sqlx::query!(
"INSERT INTO access_codes (name, code, creator_id, max_uses, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
form.name,
code,
session.user_id as i32,
form.max_uses as i32,
OffsetDateTime::from_unix_timestamp_nanos(form.start_date as i128 * 1_000_000).unwrap(),
OffsetDateTime::from_unix_timestamp_nanos(form.expiry_date as i128 * 1_000_000).unwrap()
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::InternalServerError)?;
Ok(code)
}
pub struct AccessToken {
id: i32,
_code: String,
}
impl AccessToken {
pub async fn validate(
token: &str,
db: &mut Connection<Postgres>,
) -> Result<AccessToken, String> {
match sqlx::query!(
"SELECT id FROM access_codes
WHERE code = $1
AND created_at < NOW()
AND expires_at > NOW()
AND uses < max_uses",
token
)
.fetch_one(&mut ***db)
.await
{
Ok(row) => Ok(AccessToken {
id: row.id,
_code: token.to_string(),
}),
Err(_) => Err(String::from("Invalid or Expired token!")),
}
}
pub async fn use_token(&self, db: &mut Connection<Postgres>) -> Result<(), String> {
sqlx::query!(
"UPDATE access_codes SET uses = uses + 1 WHERE id = $1",
self.id
)
.execute(&mut ***db)
.await
.map_err(|_| String::from("Invalid or Expired token!"))?;
Ok(())
}
}
-12
View File
@@ -1,12 +0,0 @@
pub mod account;
pub mod profile;
pub mod session;
pub mod two_factor;
pub use session::Session;
pub use account::{generate_invite, invite_page, login, login_page, signup, signup_page};
pub use profile::{change_display_name, change_password, change_username, delete_account};
pub use two_factor::{
confirm_totp, disable_totp, get_totp, get_totp_status, mfa_page, verify_totp,
};
-143
View File
@@ -1,143 +0,0 @@
use argon2::{
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
password_hash::{SaltString, rand_core::OsRng},
};
use rocket::{http::Status, serde::json::Json};
use rocket_db_pools::Connection;
use serde::{Deserialize, Serialize};
use crate::{auth::Session, db::Postgres, user::User};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordForm {
old_password: String,
new_password: String,
}
#[post("/settings/password", data = "<form>")]
pub async fn change_password(
session: Session,
mut db: Connection<Postgres>,
form: Json<PasswordForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.verify_password(&form.old_password)?;
// old password is correct, so new one can be set.
let salt = SaltString::generate(&mut OsRng);
let hashed = Argon2::default()
.hash_password(form.new_password.as_bytes(), &salt)
.inspect_err(|e| tracing::error!("failed to hash password! {e}"))
.map_err(|_| Status::InternalServerError)?
.to_string();
user.set_pass_hash(hashed, &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[derive(Deserialize, Debug, Clone)]
pub struct DisplayNameForm {
pub display_name: Option<String>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct PasswordAnd2fa {
pub password: String,
pub totp_code: Option<String>,
}
#[delete("/settings", data = "<data>")]
pub async fn delete_account(
session: Session,
mut db: Connection<Postgres>,
data: Json<PasswordAnd2fa>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.verify_password(&data.password)?;
if user.twofa_enabled {
user.verify_2fa(data.totp_code.as_deref().unwrap_or(""))?;
}
user.delete(&mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[patch("/settings/display_name", data = "<new>")]
pub async fn change_display_name(
session: Session,
mut db: Connection<Postgres>,
new: Json<DisplayNameForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.set_display_name(new.display_name.clone(), &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[derive(Deserialize)]
pub struct UsernameForm {
username: String,
}
#[patch("/settings/username", data = "<new>")]
pub async fn change_username(
session: Session,
mut db: Connection<Postgres>,
new: Json<UsernameForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.set_username(new.username.clone(), &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
-301
View File
@@ -1,301 +0,0 @@
use futures_util::TryFutureExt;
use rocket::{
Request,
http::Status,
outcome::{Outcome, try_outcome},
request::{self, FromRequest},
response::status::{self},
serde::json::Json,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use totp_rs::{Algorithm, Secret, TOTP};
use crate::{
auth::{
account::AuthResponse,
profile::PasswordAnd2fa,
session::{Claims, Session, TokenScope},
},
db::Postgres,
user::User,
};
// Utility methods
pub fn totp_gen(user_id: usize, secret: &[u8]) -> Result<TOTP, String> {
TOTP::new(
Algorithm::SHA1,
6,
1,
30,
secret.to_owned(),
Some("chat.zxq5.dev".to_string()),
format!("{}", user_id),
)
.map_err(|_| String::from("Invalid Secret"))
}
// pages
#[get("/totp")]
pub async fn mfa_page(_session: Session) -> Template {
Template::render("2fa", context!())
}
#[post("/totp", data = "<form>")]
pub async fn confirm_totp(
mfa: TOTPSecret,
form: Json<TOTPSixDigitCode>,
mut db: Connection<Postgres>,
) -> Result<(), status::Custom<&'static str>> {
if form.code.len() != 6 || form.code.parse::<u32>().is_err() {
return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code"));
}
println!("valid");
let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes())
.map_err(|_| status::Custom(Status::InternalServerError, "TOTP Error"))?;
if !totp.check_current(&form.code).unwrap_or(false) {
return Err(status::Custom(Status::BadRequest, "Incorrect code"));
}
println!("correct");
if sqlx::query!(
"UPDATE users SET twofa_enabled = true WHERE id = $1",
mfa.user_id as i32
)
.execute(&mut **db)
.await
.is_err()
{
return Err(status::Custom(
Status::InternalServerError,
"unable to enable 2fa",
));
};
println!("enabled");
Ok(())
}
#[derive(Deserialize)]
pub struct PasswordConfirmation {
password: String,
}
#[post("/totp.jpg", data = "<form>")]
pub async fn get_totp(
mfa: TOTPSecret,
form: Json<PasswordConfirmation>,
) -> Option<Json<QrResponse>> {
let qr_b64 = totp_gen(mfa.user_id, mfa.secret.as_bytes())
.expect("Invalid TOTP")
.get_qr_base64()
.unwrap();
Some(Json(QrResponse {
qr_code: format!("data:image/png;base64,{}", qr_b64),
}))
}
#[derive(Debug, Deserialize)]
pub struct TOTPSixDigitCode {
code: String,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TotpStatus {
Enabled,
Disabled,
}
pub struct TOTPSecret {
user_id: usize,
secret: String,
}
#[derive(Serialize)]
pub struct QrResponse {
qr_code: String,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for TOTPSecret {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
let auth_header = request.headers().get_one("Authorization");
println!(
"TOTPSecret guard - Auth header present: {}",
auth_header.is_some()
);
let user = try_outcome!(request.guard::<Claims>().await);
println!(
"TOTPSecret guard - Claims ok, user: {}, scope: {:?}",
user.sub, user.scope
);
// only allow full tokens for TOTP setup
if user.scope != TokenScope::Full {
println!("TOTPSecret guard - rejected, scope is {:?}", user.scope);
return Outcome::Error((Status::Forbidden, ()));
}
let user = try_outcome!(request.guard::<Session>().await);
let mut pool = match request.guard::<Connection<Postgres>>().await {
Outcome::Success(pool) => pool,
_ => return Outcome::Error((Status::Unauthorized, ())),
};
let row = sqlx::query!(
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
user.user_id as i32
)
.fetch_one(&mut **pool)
.await;
let (enabled, mut secret) = match row {
Ok(r) => (r.twofa_enabled, r.totp_secret),
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
};
if secret.is_none() {
let new_secret = Secret::generate_secret().to_encoded().to_string();
sqlx::query!(
"UPDATE users SET totp_secret = $1 WHERE id = $2",
new_secret,
user.user_id as i32
)
.execute(&mut **pool)
.await
.ok();
secret = Some(new_secret);
}
Outcome::Success(TOTPSecret {
user_id: user.user_id,
secret: secret.unwrap(),
})
}
}
impl TOTPSecret {
pub async fn enable(&self, db: &mut Connection<Postgres>) -> Result<(), ()> {
match sqlx::query!(
"UPDATE users SET twofa_enabled = true WHERE id = $1",
self.user_id as i32,
)
.execute(&mut ***db)
.await
{
Ok(_) => Ok(()),
Err(_) => Err(()),
}
}
}
#[derive(Deserialize)]
pub struct TotpVerifyRequest {
pub code: String,
}
#[get("/totp/status")]
pub async fn get_totp_status(
user: Session,
mut db: Connection<Postgres>,
) -> Result<Json<TotpStatus>, Status> {
Ok(Json(
if sqlx::query!(
"SELECT twofa_enabled FROM users WHERE id = $1",
user.user_id as i32,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::NotFound)?
.twofa_enabled
{
TotpStatus::Enabled
} else {
TotpStatus::Disabled
},
))
}
#[delete("/totp", data = "<form>")]
pub async fn disable_totp(
user: Session,
mut db: Connection<Postgres>,
form: Json<PasswordAnd2fa>,
) -> Result<Json<AuthResponse>, Status> {
let totp_code = form.totp_code.clone().ok_or(Status::BadRequest)?;
let mut user = User::get_by_id(user.user_id, &mut db)
.await
.ok_or(Status::NotFound)?;
user.verify_password(&form.password)?;
user.verify_2fa(&totp_code)?;
user.set_twofa_enabled(false, &mut db)
.await
.map_err(|_| Status::InternalServerError)?;
Ok(Json(AuthResponse {
token: Claims::new(user.id as usize, TokenScope::Full).encode(),
totp_required: false,
}))
}
#[post("/totp/verify", data = "<body>")]
pub async fn verify_totp(
claims: Claims, // request guard checks token validity
mut db: Connection<Postgres>,
body: Json<TotpVerifyRequest>,
) -> Result<Json<AuthResponse>, Status> {
println!("reached 1");
// reject if they somehow got here with a full token
if claims.scope != TokenScope::TotpPending {
return Err(Status::Forbidden);
}
println!("reached 2");
let row = sqlx::query!(
"SELECT totp_secret FROM users WHERE id = $1 AND twofa_enabled = TRUE",
claims.sub
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::Unauthorized)?;
println!("reached 3");
let totp = totp_gen(
claims.sub as usize,
row.totp_secret
.expect("user with 2fa enabled has no totp secret")
.as_bytes(),
)
.map_err(|_| Status::InternalServerError)?;
if !totp
.check_current(&body.code)
.map_err(|_| Status::InternalServerError)?
{
return Err(Status::Unauthorized);
}
println!("reached 5");
let claims = Claims::new(claims.sub as usize, TokenScope::Full);
Ok(Json(AuthResponse {
token: claims.encode(),
totp_required: false,
}))
}
+96
View File
@@ -0,0 +1,96 @@
use clap::{Parser, Subcommand};
use sqlx::postgres::PgPoolOptions;
use std::time::Duration;
use std::sync::Arc;
use crate::repo::user_repo::UserRepository;
use crate::repo::space_repo::SpaceRepository;
use crate::repo::channel_repo::ChannelRepository;
use crate::repo::{UserRepo, SpaceRepo, ChannelRepo};
use argon2::{
password_hash::{PasswordHasher, SaltString},
Argon2,
};
use rand::rngs::OsRng;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
pub struct Cli {
#[command(subcommand)]
pub command: Option<Commands>,
}
#[derive(Subcommand)]
pub enum Commands {
/// First-time setup for the server
Setup {
/// Admin username
#[arg(short, long)]
username: String,
/// Admin password
#[arg(short, long)]
password: String,
/// Default space name
#[arg(short, long, default_value = "Default Space")]
space: String,
/// Default channel name
#[arg(short, long, default_value = "general")]
channel: String,
},
}
pub async fn handle_cli() -> bool {
let cli = Cli::parse();
match cli.command {
Some(Commands::Setup { username, password, space, channel }) => {
if let Err(e) = run_setup(username, password, space, channel).await {
eprintln!("Setup failed: {}", e);
std::process::exit(1);
}
println!("Setup completed successfully!");
true
}
None => false,
}
}
async fn run_setup(username: String, password: String, space_name: String, channel_name: String) -> Result<(), Box<dyn std::error::Error>> {
dotenv::dotenv().ok();
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let pool = PgPoolOptions::new()
.max_connections(1)
.acquire_timeout(Duration::from_secs(5))
.connect(&db_url)
.await?;
let user_repo = UserRepository::new(pool.clone());
let space_repo = SpaceRepository::new(pool.clone());
let channel_repo = ChannelRepository::new(pool.clone());
// 1. Create admin user
println!("Creating admin user: {}...", username);
let argon2 = Argon2::default();
let salt = SaltString::generate(&mut OsRng);
let passhash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| e.to_string())?
.to_string();
let user_id = user_repo.new_user("admin@localhost", &username, &passhash).await?;
user_repo.set_role(user_id, "admin").await?;
// 2. Create default space
println!("Creating default space: {}...", space_name);
let space_id = space_repo.create(&space_name, Some("Default space created during setup"), user_id).await?;
// 3. Create default channel
println!("Creating default channel: {}...", channel_name);
channel_repo.create(&channel_name, Some("Default channel"), space_id).await?;
Ok(())
}
-9
View File
@@ -1,9 +0,0 @@
use rocket_db_pools::{Database, deadpool_redis};
#[derive(Database)]
#[database("postgres_db")]
pub struct Postgres(sqlx::PgPool);
#[derive(Database)]
#[database("redis_cache")]
pub struct Redis(deadpool_redis::Pool);
+127
View File
@@ -0,0 +1,127 @@
// error.rs
use rocket::{http::Status, response::{self, Responder}, Request, Response};
use thiserror::Error;
use rocket_dyn_templates::Template;
use rocket::serde::Serialize;
#[derive(Error, Debug)]
pub enum AppError {
#[error("Not found")]
NotFound,
#[error("Unauthorized")]
Unauthorised(String),
#[error("Forbidden")]
Forbidden,
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Internal error: {0}")]
Internal(String),
}
impl AppError {
pub fn internal(msg: impl Into<String>) -> Self {
Self::Internal(msg.into())
}
pub fn bad_request(msg: impl Into<String>) -> Self {
Self::BadRequest(msg.into())
}
pub fn unauthorised(msg: impl Into<String>) -> Self {
Self::Unauthorised(msg.into())
}
}
impl<'r> Responder<'r, 'static> for AppError {
fn respond_to(self, _req: &'r Request<'_>) -> response::Result<'static> {
let status = match &self {
AppError::NotFound => Status::NotFound,
AppError::Unauthorised(_) => Status::Unauthorized,
AppError::Forbidden => Status::Forbidden,
AppError::BadRequest(_) => Status::BadRequest,
AppError::Database(_) => Status::InternalServerError,
AppError::Internal(_) => Status::InternalServerError,
};
// log internal errors
if status == Status::InternalServerError {
tracing::error!("Internal Server Error: {}", self);
}
Response::build()
.status(status)
.header(rocket::http::ContentType::Plain)
.sized_body(
self.to_string().len(),
std::io::Cursor::new(self.to_string())
)
.ok()
}
}
pub type ApiResult<T> = Result<T, AppError>;
#[derive(Serialize)]
struct ErrorContext {
error_code: u16,
error_message: &'static str,
additional_info: &'static str,
redirect: Option<RedirectContext>,
}
#[derive(Serialize)]
struct RedirectContext {
url: &'static str,
message: &'static str,
}
#[catch(404)]
pub async fn handle_404() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 404,
error_message: "Not Found",
additional_info: "There's nothing here.",
redirect: Some(RedirectContext {
url: "/",
message: "Home",
}),
},
)
}
#[catch(401)]
pub async fn handle_401() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 401,
error_message: "Unauthorised",
additional_info: "You are not authorised to access this resource.",
redirect: Some(RedirectContext {
url: "/login",
message: "Login",
}),
},
)
}
#[catch(default)]
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
Template::render(
"error",
ErrorContext {
error_code: status.code,
error_message: "Unknown Error",
additional_info: "I don't know what to do with this error.",
redirect: None,
},
)
}
-62
View File
@@ -1,62 +0,0 @@
use rocket::{Request, http::Status};
use rocket_dyn_templates::Template;
use serde::Serialize;
#[derive(Serialize)]
struct ErrorContext {
error_code: u16,
error_message: &'static str,
additional_info: &'static str,
redirect: Option<RedirectContext>,
}
#[derive(Serialize)]
struct RedirectContext {
url: &'static str,
message: &'static str,
}
#[catch(404)]
pub async fn handle_404() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 404,
error_message: "Not Found",
additional_info: "There's nothing here.",
redirect: Some(RedirectContext {
url: "/",
message: "Home",
}),
},
)
}
#[catch(401)]
pub async fn handle_401() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 401,
error_message: "Unauthorised",
additional_info: "You are not authorised to access this resource.",
redirect: Some(RedirectContext {
url: "/login",
message: "Login",
}),
},
)
}
#[catch(default)]
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
Template::render(
"error",
ErrorContext {
error_code: status.code,
error_message: "Unknown Error",
additional_info: "I don't know what to do with this error.",
redirect: None,
},
)
}
+149
View File
@@ -0,0 +1,149 @@
#![deny(clippy::unwrap_used)]
#![warn(clippy::all, clippy::nursery, clippy::cargo, clippy::pedantic)]
#[macro_use]
extern crate rocket;
pub mod messenger;
pub mod api;
pub mod repo;
pub mod error;
pub mod svc;
pub mod model;
pub mod cli;
use crate::repo::{access_token_repo::AccessTokenRepo, Repo};
use crate::repo::message_repo::MessageRepository;
use crate::repo::user_repo::UserRepository;
use crate::repo::space_repo::SpaceRepository;
use crate::repo::channel_repo::ChannelRepository;
use crate::svc::auth_svc::AuthService;
use crate::svc::chat_svc::ChatService;
use crate::svc::settings_svc::SettingsService;
use crate::svc::user_svc::UserService;
use rocket::fs::{FileServer, NamedFile};
use rocket::http::Method;
use rocket_cors::{AllowedOrigins, CorsOptions};
use rocket_dyn_templates::Template;
use sqlx::postgres::PgPoolOptions;
use std::env;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use api::cdn;
use crate::svc::access_token_svc::AccessTokenService;
use crate::svc::llm_service::LlmService;
pub fn rocket() -> rocket::Rocket<rocket::Build> {
if std::env::var("RELEASE_MODE").unwrap_or_default() != "1" {
dotenv::dotenv().ok();
}
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let pool = PgPoolOptions::new()
.max_connections(25)
.min_connections(5)
.acquire_timeout(Duration::from_secs(5))
.connect_lazy(&db_url)
.expect("Failed to create database pool");
let user_repo = Arc::new(UserRepository::new(pool.clone()));
let message_repo = MessageRepository::new(pool.clone());
let token_repo = Arc::new(AccessTokenRepo::new(pool.clone()));
let space_repo: Arc<dyn repo::SpaceRepo> = Arc::new(SpaceRepository::new(pool.clone()));
let channel_repo: Arc<dyn repo::ChannelRepo> = Arc::new(ChannelRepository::new(pool.clone()));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, llm_service.clone(), message_repo.clone(), user_repo.clone(), channel_repo.clone(), space_repo.clone());
rocket_builder(user_repo, token_repo, space_repo, channel_repo, chat_service)
}
pub fn rocket_builder(
user_repo: Arc<dyn repo::UserRepo>,
token_repo: Arc<dyn repo::AccessTokenRepoTrait>,
space_repo: Arc<dyn repo::SpaceRepo>,
channel_repo: Arc<dyn repo::ChannelRepo>,
chat_service: ChatService
) -> rocket::Rocket<rocket::Build> {
let cors = CorsOptions::default()
.allowed_origins(AllowedOrigins::all())
.allowed_methods(
vec![Method::Get, Method::Post, Method::Patch]
.into_iter()
.map(From::from)
.collect(),
)
.allow_credentials(true);
let access_token_svc = AccessTokenService::new(token_repo.clone());
let auth_service = AuthService::new(user_repo.clone(), access_token_svc.clone());
let settings_service = SettingsService::new(auth_service.clone(), user_repo.clone());
let user_service = UserService::new(user_repo.clone());
rocket::build()
.manage(chat_service)
.manage(auth_service)
.manage(settings_service)
.manage(user_service)
.manage(space_repo)
.manage(channel_repo)
.attach(cors.to_cors().unwrap())
.attach(Template::fairing())
.mount("/static", FileServer::from("static"))
.mount("/cdn", cdn::routes())
.mount(
"/",
routes![
favicon,
],
)
.mount(
"/api",
routes![
cdn::upload_profile_pic,
api::profile::display_name,
// basic auth
api::auth::login,
api::auth::signup,
// 2fa
api::totp::confirm_totp,
api::totp::disable_totp,
api::totp::get_totp,
api::totp::get_totp_status,
api::totp::verify_totp,
// chat
api::chat::event_stream,
api::chat::post_message,
// user settings
api::settings::change_display_name,
api::settings::change_password,
api::settings::change_username,
api::settings::delete_account,
// spaces
api::space::list_spaces,
api::space::list_channels,
api::space::get_accessible_channels
],
)
.register(
"/",
catchers![
error::handle_401,
error::handle_404,
error::handle_default,
],
)
}
#[get("/favicon.ico")]
pub async fn favicon() -> NamedFile {
NamedFile::open("static/favicon.ico").await.unwrap()
}
-69
View File
@@ -1,69 +0,0 @@
// src/llm.rs
use serde::{Deserialize, Serialize};
use crate::messenger::ChatMsg;
#[derive(Serialize)]
struct LlmRequest {
model: String,
messages: Vec<Message>,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String, // "user" or "assistant"
content: String,
}
pub struct LlmWorker {
uri: String,
}
impl LlmWorker {
pub fn new(uri: String) -> Self {
Self { uri }
}
pub async fn query(&self, message: &ChatMsg) -> Result<ChatMsg, String> {
let client = reqwest::Client::new();
// Build the request body
let payload = LlmRequest {
model: "gpt-oss-20b".into(), // whatever model you run locally
messages: vec![Message {
role: "user".into(),
content: message.text.clone(),
}],
};
// POST to lmstudio (default 127.0.0.1:1234)
let resp = client
.post(self.uri.clone())
.json(&payload)
.send()
.await
.map_err(|_| String::from("Failed to make request to LLM server"))?;
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)]
struct LlmResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize)]
struct Choice {
message: Message,
}
let llm_resp: LlmResponse = resp
.json()
.await
.map_err(|_| String::from("Failed to make request to LLM server"))?;
Ok(ChatMsg {
display_name: Some(String::from("lmstudio")),
user_id: 0,
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now().timestamp_millis() as usize,
})
}
}
+8 -95
View File
@@ -1,98 +1,11 @@
// src/main.rs
#[macro_use]
extern crate rocket;
use backend::rocket;
use backend::cli::handle_cli;
use rocket::fs::{FileServer, NamedFile};
use rocket::http::Method;
use rocket::{Build, Rocket};
use rocket_cors::{AllowedOrigins, CorsOptions};
use rocket_db_pools::Database;
use rocket_dyn_templates::Template;
use std::env;
use std::sync::{Arc, LazyLock};
use crate::db::{Postgres, Redis};
pub mod auth;
pub mod cdn;
pub mod db;
pub mod handlers;
pub mod llm;
pub mod messenger;
pub mod user;
static LMSTUDIO_URL: LazyLock<String> =
LazyLock::new(|| env::var("LMSTUDIO_URL").expect("Ensure LMSTUDIO_URL is set!"));
#[launch]
fn rocket() -> Rocket<Build> {
// make sure the env is loaded
dotenv::dotenv().expect("Failed to load env! aborting launch!");
let chat = Arc::new(crate::messenger::ChatBroadcaster::new(32));
let cors = CorsOptions::default()
.allowed_origins(AllowedOrigins::all())
.allowed_methods(
vec![Method::Get, Method::Post, Method::Patch]
.into_iter()
.map(From::from)
.collect(),
)
.allow_credentials(true);
rocket::build()
.manage(chat)
.attach(cors.to_cors().unwrap())
.attach(Postgres::init())
.attach(Redis::init())
.attach(Template::fairing())
.mount("/static", FileServer::from("static"))
.mount("/cdn", cdn::routes())
.mount(
"/",
routes![
favicon,
messenger::chat_page,
auth::signup_page,
auth::login_page,
auth::mfa_page,
auth::invite_page,
],
)
.mount(
"/api",
routes![
cdn::upload_profile_pic,
messenger::post_message,
messenger::event_stream,
user::users,
user::display_name,
auth::signup,
auth::login,
auth::get_totp,
auth::confirm_totp,
auth::generate_invite,
auth::verify_totp,
auth::disable_totp,
auth::get_totp_status,
auth::change_password,
auth::change_display_name,
auth::change_username,
auth::delete_account,
],
)
.register(
"/",
catchers![
handlers::handle_404,
handlers::handle_401,
handlers::handle_default
],
)
#[rocket::main]
async fn main() -> Result<(), rocket::Error> {
if handle_cli().await {
return Ok(());
}
#[get("/favicon.ico")]
async fn favicon() -> NamedFile {
NamedFile::open("static/favicon.ico").await.unwrap()
rocket().launch().await?;
Ok(())
}
+2 -4
View File
@@ -1,10 +1,8 @@
use redis::AsyncCommands;
use rocket_db_pools::Connection;
use crate::{
db::{Postgres, Redis},
messenger::ChatMsg,
};
use crate::api::chat::ChatMsg;
use crate::db::{Postgres, Redis};
// Helper function to cache message in Redis
pub async fn insert(
-220
View File
@@ -1,220 +0,0 @@
use std::sync::Arc;
use rocket::{
Shutdown,
response::stream::{Event, EventStream},
serde::json::Json,
time::OffsetDateTime,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use sqlx::prelude::FromRow;
use tokio::{select, sync::broadcast};
use crate::{
auth::Session,
db::{Postgres, Redis},
llm::LlmWorker,
messenger,
};
/// ---------- shared broadcaster ----------
pub struct ChatBroadcaster {
buffer_size: usize,
senders: std::sync::Mutex<std::collections::HashMap<i32, broadcast::Sender<ChatMsg>>>,
}
impl ChatBroadcaster {
pub fn new(buffer_size: usize) -> Self {
Self {
buffer_size,
senders: std::sync::Mutex::new(std::collections::HashMap::new()),
}
}
/// Publish a message to the specified channel.
pub async fn publish(&self, channel_id: i32, msg: ChatMsg) {
let mut map = self.senders.lock().unwrap();
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
let _ = sender.send(msg);
}
/// Subscribe to the specified channel.
pub fn subscribe(&self, channel_id: i32) -> broadcast::Receiver<ChatMsg> {
let mut map = self.senders.lock().unwrap();
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
sender.subscribe()
}
}
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub display_name: Option<String>,
pub user_id: usize,
pub text: String,
pub timestamp: usize,
}
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message(
mut msg: Json<ChatMsg>,
chat: &rocket::State<Arc<ChatBroadcaster>>,
mut postgres: Connection<Postgres>,
mut cache: Option<Connection<Redis>>,
session: Session,
channel_id: i32,
) -> Result<(), String> {
let chat = chat.inner().clone();
let display_name = sqlx::query!(
"SELECT display_name, username FROM users WHERE id = $1",
session.user_id as i32
)
.fetch_one(&mut **postgres)
.await
.map(|row| row.display_name.unwrap_or(row.username))
.unwrap_or_else(|_| "Unknown".to_string());
msg.user_id = session.user_id;
msg.display_name = Some(display_name);
chat.publish(channel_id, msg.clone().into_inner()).await;
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
channel_id,
msg.user_id as i32,
msg.text,
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
)
.execute(&mut **postgres)
.await
.map_err(|_| "Failed".to_string())?;
println!("gisfujdeghnjuisdfjngiosdfgjkosdf gnojdfsg nmodfsg");
if let Some(ref mut cache) = cache {
messenger::cache::insert(cache, channel_id, &msg)
.await
.map_err(|_| "Redis cache failed".to_string())?;
}
// get response
tokio::spawn(async move {
let response = LlmWorker::new(crate::LMSTUDIO_URL.to_string())
.query(&msg)
.await;
if let Ok(reply) = response {
chat.publish(channel_id, reply.clone()).await;
if let Some(ref mut cache) = cache {
messenger::cache::insert(cache, channel_id, &reply)
.await
.ok();
}
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
channel_id,
reply.user_id as i32,
reply.text,
OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
)
.execute(&mut **postgres)
.await
.map_err(|_| "Failed".to_string())
.unwrap();
}
});
Ok(())
}
pub async fn get_messages(
mut db: Connection<Postgres>,
mut redis: Connection<Redis>,
channel_id: i32,
) -> Json<Vec<ChatMsg>> {
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
&& !messages.is_empty()
{
return Json(messages);
};
if let Err(x) = messenger::cache::initialise(&mut redis, &mut db, channel_id).await {
eprintln!("WARN: {x:?}");
}
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
&& !messages.is_empty()
{
return Json(messages);
};
let res = sqlx::query!(
"SELECT u.username, u.display_name, u.id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.channel_id = $1
ORDER BY m.created_at DESC LIMIT 100",
channel_id
)
.fetch_all(&mut **db)
.await
.unwrap_or_else(|_| Vec::new())
.into_iter()
.rev()
.map(|msg| ChatMsg {
display_name: Some(msg.display_name.unwrap_or(msg.username)),
user_id: msg.id as usize,
text: msg.content,
timestamp: (msg.created_at.unwrap().unix_timestamp_nanos() / 1_000_000) as usize,
})
.collect();
Json(res)
}
#[get("/events/<channel_id>")]
pub async fn event_stream(
chat: &rocket::State<Arc<ChatBroadcaster>>,
postgres: Connection<Postgres>,
cache: Connection<Redis>,
_session: Session,
mut shutdown: Shutdown,
channel_id: i32,
) -> EventStream![] {
let mut rx = chat.subscribe(channel_id);
EventStream! {
// Initialize the stream with the last 100 messages
for msg in get_messages(postgres, cache, channel_id).await.0 {
yield Event::json(&msg);
}
loop {
select!{
// exit early on shutdown
_ = &mut shutdown => break,
msg = rx.recv() => match msg {
Ok(msg) => yield Event::json(&msg),
Err(broadcast::error::RecvError::Lagged(_)) => {
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => break,
},
}
}
}
}
#[get("/chat")]
pub async fn chat_page(session: Session) -> Template {
Template::render("chat", context!(user_id: session.user_id))
}
+1 -4
View File
@@ -1,4 +1 @@
mod cache;
mod messages;
pub use messages::{ChatBroadcaster, ChatMsg, chat_page, event_stream, get_messages, post_message};
// mod cache;
+35
View File
@@ -0,0 +1,35 @@
use rocket::serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
#[derive(Serialize, Deserialize)]
pub struct SignupCredentials {
pub email: String,
pub username: String,
pub password: String,
pub access_token: String,
}
#[derive(Serialize, Deserialize)]
pub struct LoginCredentials {
pub username: String,
pub password: String,
}
#[derive(Serialize, Deserialize)]
pub struct AuthResponse {
pub token: String,
pub totp_required: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenForm {
pub name: String,
pub max_uses: i32,
pub expiry_date: DateTime<Utc>,
pub start_date: DateTime<Utc>,
}
pub struct AccessToken {
pub id: i64,
pub code: String,
}
+3
View File
@@ -0,0 +1,3 @@
pub mod auth;
pub mod user;
pub mod space;
+35
View File
@@ -0,0 +1,35 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use chrono::{DateTime, Utc};
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Space {
pub id: i64,
pub name: String,
pub description: Option<String>,
pub owner_id: i64,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Channel {
pub id: i64,
pub name: String,
pub description: Option<String>,
pub space_id: i64,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpaceDto {
pub channels: Vec<Channel>,
pub id: i64,
pub owner_id: i64,
pub name: String,
pub description: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
+52
View File
@@ -0,0 +1,52 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::user_svc::UserService;
use chrono::{DateTime, Utc};
use rocket::State;
use sqlx::FromRow;
use crate::api::totp::TotpStatus;
#[derive(Clone)]
#[derive(FromRow)]
pub struct User {
pub id: i64,
pub email: Option<String>,
pub username: String,
pub nickname: Option<String>,
pub passhash: String,
pub totp_status: TotpStatus,
pub totp_secret: Option<String>,
pub created_at: Option<DateTime<Utc>>,
pub updated_at: Option<DateTime<Utc>>,
}
// pub struct UserCache {}
//
// impl UserCache {
// pub async fn username(
// id: usize,
// redis_conn: &mut Connection<Redis>,
// pgsql_conn: &mut Connection<Postgres>,
// ) -> String {
// if let Ok(val) = redis_conn.get(format!("users:{id}")).await {
// return val;
// }
//
// if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
// .fetch_one(&mut ***pgsql_conn)
// .await
// {
// let username = v.username;
// Self::insert(id, &username, redis_conn).await;
// username
// } else {
// unimplemented!()
// }
// }
//
// pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
// conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800)
// .await
// .expect("failed to insert key");
// }
// }
+67
View File
@@ -0,0 +1,67 @@
use crate::repo::{Repo, AccessTokenRepoTrait};
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use crate::model::auth::AccessToken;
#[derive(Clone)]
pub struct AccessTokenRepo {
pool: PgPool
}
impl Repo for AccessTokenRepo {
type Target = AccessToken;
fn new(pool: PgPool) -> Self {
Self { pool }
}
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
sqlx::query_as!(AccessToken, "SELECT id, code FROM access_tokens WHERE id = $1", id)
.fetch_optional(&self.pool)
.await
.unwrap_or(None)
}
}
#[async_trait]
impl AccessTokenRepoTrait for AccessTokenRepo {
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
Repo::get_by_id(self, id).await
}
async fn create_new(&self,
uid: i64, name: &str, code: &str, max_uses: i32,
start_date: DateTime<Utc>, expiry_date: DateTime<Utc>
) -> Result<i64, sqlx::Error> {
sqlx::query!(
"INSERT INTO access_tokens (name, code, creator_id, max_uses, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
name,
code,
uid,
max_uses,
start_date,
expiry_date
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
}
async fn use_token(&self, id: i64) -> Result<(), sqlx::Error> {
sqlx::query!("UPDATE access_tokens SET uses = uses + 1 WHERE id = $1", id)
.execute(&self.pool)
.await?;
Ok(())
}
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error> {
sqlx::query_as!(AccessToken,
"SELECT id, code FROM access_tokens
WHERE code = $1
AND created_at < NOW()
AND expires_at > NOW()
AND uses < max_uses",
code
)
.fetch_optional(&self.pool)
.await
}
}
+39
View File
@@ -0,0 +1,39 @@
use crate::repo::ChannelRepo;
use crate::model::space::Channel;
use sqlx::PgPool;
#[derive(Clone)]
pub struct ChannelRepository {
pool: PgPool,
}
impl ChannelRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
#[rocket::async_trait]
impl ChannelRepo for ChannelRepository {
async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result<i64, sqlx::Error> {
let row = sqlx::query!(
"INSERT INTO channels (name, description, space_id) VALUES ($1, $2, $3) RETURNING id",
name,
description,
space_id
)
.fetch_one(&self.pool)
.await?;
Ok(row.id)
}
async fn get_by_space_id(&self, space_id: i64) -> Result<Vec<Channel>, sqlx::Error> {
sqlx::query_as!(
Channel,
"SELECT id, name, description, space_id, created_at as \"created_at!\", updated_at as \"updated_at!\" FROM channels WHERE space_id = $1",
space_id
)
.fetch_all(&self.pool)
.await
}
}
+95
View File
@@ -0,0 +1,95 @@
use crate::api::chat::ChatMsg;
use crate::repo::Repo;
use chrono::{DateTime, Utc};
use sqlx::PgPool;
#[derive(Clone)]
pub struct MessageRepository {
pool: PgPool
}
impl Repo for MessageRepository {
type Target = ChatMsg;
fn new(pool: PgPool) -> Self {
Self { pool }
}
// TODO: caching with redis
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.id = $1",
id
).fetch_optional(&self.pool).await.ok().flatten().map(|row| ChatMsg {
display_name: Some(row.nickname.unwrap_or(row.username)),
user_id: row.user_id,
text: row.content,
timestamp: row.created_at,
})
}
}
impl MessageRepository {
// TODO! caching with redis
pub async fn create_new(
&self, uid: i64, channel_id: i64,
text: &str, created_at: DateTime<Utc>
) -> Result<i64, sqlx::Error> {
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at)
VALUES ($1, $2, $3, $4) RETURNING id",
channel_id,
uid,
text,
created_at
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
}
/// TODO: caching with redis
pub async fn get_by_channel(&self, channel_id: i64, limit: usize)
-> Result<Vec<ChatMsg>, sqlx::Error> {
sqlx::query!(
"SELECT u.username, u.nickname, u.id as user_id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.channel_id = $1
ORDER BY m.created_at DESC LIMIT $2",
channel_id,
limit as i64
).fetch_all(&self.pool).await.map(|messages| {
messages.into_iter().rev().map(|msg| {
ChatMsg {
display_name: Some(msg.nickname.unwrap_or(msg.username)),
user_id: msg.user_id,
text: msg.content,
timestamp: msg.created_at,
}
}).collect::<Vec<_>>()
})
}
}
+153
View File
@@ -0,0 +1,153 @@
use crate::repo::{UserRepo, AccessTokenRepoTrait};
use crate::model::user::User;
use crate::model::auth::AccessToken;
use rocket::async_trait;
use std::sync::Mutex;
use chrono::Utc;
use std::sync::Arc;
use sqlx::Error;
use crate::api::totp::TotpStatus;
use crate::api::totp::TotpStatus::Disabled;
pub struct MockAccessTokenRepo {
pub tokens: Mutex<Vec<AccessToken>>,
}
#[async_trait]
impl AccessTokenRepoTrait for MockAccessTokenRepo {
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
self.tokens.lock().unwrap().iter().find(|t| t.id == id).map(|t| AccessToken { id: t.id, code: t.code.clone() })
}
async fn create_new(&self, _uid: i64, _name: &str, code: &str, _max_uses: i32, _start_date: chrono::DateTime<Utc>, _expiry_date: chrono::DateTime<Utc>) -> Result<i64, sqlx::Error> {
let mut tokens = self.tokens.lock().unwrap();
let id = tokens.len() as i64 + 1;
tokens.push(AccessToken { id, code: code.to_string() });
Ok(id)
}
async fn use_token(&self, id: i64) -> Result<(), Error> {
// let mut tokens = self.tokens.lock().unwrap();
// if let Some(pos) = tokens.iter().position(|t| t.id == id) {
// tokens.get_mut(pos).uses =
// }
Ok(())
}
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, Error> {
Ok(self.tokens.lock().unwrap()
.iter().find(|t| t.code == code)
.map(|t| AccessToken { id: t.id, code: t.code.clone() }))
}
}
pub struct MockUserRepo {
pub users: Mutex<Vec<User>>,
}
#[async_trait]
impl UserRepo for MockUserRepo {
fn pool(&self) -> &sqlx::PgPool {
unimplemented!("MockUserRepo does not have a real pool")
}
async fn get_by_id(&self, id: i64) -> Option<User> {
self.users.lock().unwrap().iter().find(|u| u.id == id).cloned()
}
async fn save(&self, user: &User) -> Result<(), sqlx::Error> {
let mut users = self.users.lock().unwrap();
if let Some(pos) = users.iter().position(|u| u.id == user.id) {
users[pos] = user.clone();
}
Ok(())
}
async fn new_user(&self, email: &str, username: &str, pass_hash: &str) -> Result<i64, sqlx::Error> {
let mut users = self.users.lock().unwrap();
let id = users.len() as i64 + 1;
users.push(User {
id,
email: Some(email.to_string()),
username: username.to_string(),
nickname: None,
passhash: pass_hash.to_string(),
totp_status: Disabled,
totp_secret: None,
created_at: Some(Utc::now()),
updated_at: Some(Utc::now()),
});
Ok(id)
}
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error> {
Ok(self.users.lock().unwrap().iter().find(|u| u.username == username).cloned())
}
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error> {
self.users.lock().unwrap().retain(|u| u.id != id);
Ok(())
}
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.nickname = display_name;
}
Ok(())
}
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.username = username.to_string();
}
Ok(())
}
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.totp_status = *enabled;
}
Ok(())
}
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error> {
Ok(self.users.lock().unwrap().iter().find(|u| u.id == id).and_then(|u| u.totp_secret.clone()))
}
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.totp_secret = secret;
}
Ok(())
}
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error> {
Ok(self.users.lock().unwrap().iter().find(|u| u.id == id).map(|u| u.passhash.clone()).unwrap())
}
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.passhash = pass_hash.to_string();
}
Ok(())
}
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error> {
if let Some(u) = self.users.lock().unwrap().iter_mut().find(|u| u.id == id) {
u.email = Some(email.to_string());
}
Ok(())
}
async fn set_role(&self, _id: i64, _role: &str) -> Result<(), sqlx::Error> {
Ok(())
}
}
pub struct MockTokenRepo {
pub tokens: Mutex<Vec<AccessToken>>,
}
#[async_trait]
impl AccessTokenRepoTrait for MockTokenRepo {
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
self.tokens.lock().unwrap().iter().find(|t| t.id == id).map(|t| AccessToken { id: t.id, code: t.code.clone() })
}
async fn create_new(&self, _uid: i64, _name: &str, code: &str, _max_uses: i32, _start_date: chrono::DateTime<Utc>, _expiry_date: chrono::DateTime<Utc>) -> Result<i64, sqlx::Error> {
let mut tokens = self.tokens.lock().unwrap();
let id = tokens.len() as i64 + 1;
tokens.push(AccessToken { id, code: code.to_string() });
Ok(id)
}
async fn use_token(&self, _id: i64) -> Result<(), sqlx::Error> {
Ok(())
}
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error> {
Ok(self.tokens.lock().unwrap().iter().find(|t| t.code == code).map(|t| AccessToken { id: t.id, code: t.code.clone() }))
}
}
+64
View File
@@ -0,0 +1,64 @@
use crate::model::auth::AccessToken;
use crate::model::user::User;
use chrono::{DateTime, Utc};
use crate::api::totp::TotpStatus;
use crate::model::space::Space;
pub mod user_repo;
pub mod message_repo;
pub mod access_token_repo;
pub mod space_repo;
pub mod channel_repo;
pub mod mock;
pub trait Repo: Clone + Send + Sync {
type Target;
fn new(pool: sqlx::PgPool) -> Self;
async fn get_by_id(&self, id: i64) -> Option<Self::Target>;
}
#[rocket::async_trait]
pub trait UserRepo: Send + Sync {
fn pool(&self) -> &sqlx::PgPool;
async fn get_by_id(&self, id: i64) -> Option<User>;
async fn save(&self, user: &User) -> Result<(), sqlx::Error>;
async fn new_user(&self, email: &str, username: &str, pass_hash: &str) -> Result<i64, sqlx::Error>;
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error>;
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error>;
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error>;
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error>;
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error>;
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error>;
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error>;
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error>;
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error>;
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error>;
async fn set_role(&self, id: i64, role: &str) -> Result<(), sqlx::Error>;
}
#[rocket::async_trait]
pub trait SpaceRepo: Send + Sync {
async fn create(&self, name: &str, description: Option<&str>, owner_id: i64) -> Result<i64, sqlx::Error>;
async fn get_all(&self) -> Result<Vec<crate::model::space::Space>, sqlx::Error>;
async fn get_by_member(&self, uid: i64) -> Result<Vec<Space>, sqlx::Error>;
async fn get_by_id(&self, id: i64) -> Result<Option<crate::model::space::Space>, sqlx::Error>;
}
#[rocket::async_trait]
pub trait ChannelRepo: Send + Sync {
async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result<i64, sqlx::Error>;
async fn get_by_space_id(&self, space_id: i64) -> Result<Vec<crate::model::space::Channel>, sqlx::Error>;
}
#[async_trait]
pub trait AccessTokenRepoTrait: Send + Sync {
async fn get_by_id(&self, id: i64) -> Option<AccessToken>;
async fn create_new(&self,
uid: i64, name: &str, code: &str, max_uses: i32,
start_date: DateTime<Utc>, expiry_date: DateTime<Utc>
) -> Result<i64, sqlx::Error>;
async fn use_token(&self, id: i64) -> Result<(), sqlx::Error>;
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error>;
}
+56
View File
@@ -0,0 +1,56 @@
use crate::repo::SpaceRepo;
use crate::model::space::Space;
use sqlx::PgPool;
#[derive(Clone)]
pub struct SpaceRepository {
pool: PgPool,
}
impl SpaceRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
#[rocket::async_trait]
impl SpaceRepo for SpaceRepository {
async fn create(&self, name: &str, description: Option<&str>, owner_id: i64) -> Result<i64, sqlx::Error> {
let row = sqlx::query!(
"INSERT INTO spaces (name, description, owner_id) VALUES ($1, $2, $3) RETURNING id",
name,
description,
owner_id
)
.fetch_one(&self.pool)
.await?;
Ok(row.id)
}
async fn get_all(&self) -> Result<Vec<Space>, sqlx::Error> {
sqlx::query_as!(Space,
"SELECT id, name, description, owner_id, created_at, updated_at FROM spaces"
)
.fetch_all(&self.pool)
.await
}
async fn get_by_member(&self, uid: i64) -> Result<Vec<Space>, sqlx::Error> {
sqlx::query_as!(Space,
"SELECT s.id, s.name, s.description, s.created_at, s.updated_at, s.owner_id
FROM spaces s JOIN space_members sm ON s.id = sm.space_id
WHERE sm.user_id = $1",
uid
).fetch_all(&self.pool)
.await
}
async fn get_by_id(&self, id: i64) -> Result<Option<Space>, sqlx::Error> {
sqlx::query_as!(Space,
"SELECT id, name, description, owner_id, created_at, updated_at FROM spaces WHERE id = $1",
id
)
.fetch_optional(&self.pool)
.await
}
}
+212
View File
@@ -0,0 +1,212 @@
use crate::repo::{Repo, UserRepo};
use crate::model::user::User;
use sqlx::PgPool;
use crate::api::totp::TotpStatus;
#[derive(Clone)]
pub struct UserRepository {
pool: PgPool
}
impl UserRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
pub fn pool(&self) -> &PgPool {
&self.pool
}
}
impl Repo for UserRepository {
type Target = User;
fn new(pool: PgPool) -> Self {
Self::new(pool)
}
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
sqlx::query_as!(
User,
"SELECT id, email, username, nickname, passhash, totp_status as \"totp_status!: TotpStatus\", totp_secret, created_at, updated_at FROM users WHERE id = $1",
id
)
.fetch_optional(&self.pool)
.await
.map_err(|e| {
tracing::error!("Database error in get_by_id: {}", e);
e
})
.ok()?
}
}
#[async_trait]
impl UserRepo for UserRepository {
fn pool(&self) -> &sqlx::PgPool {
&self.pool
}
async fn get_by_id(&self, id: i64) -> Option<User> {
Repo::get_by_id(self, id).await
}
async fn save(&self, user: &User) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET email = $1, username = $2, nickname = $3, passhash = $4, totp_status = $5, totp_secret = $6, created_at = $7, updated_at = $8 WHERE id = $9",
user.email,
user.username,
user.nickname,
user.passhash,
user.totp_status as TotpStatus,
user.totp_secret,
user.created_at,
user.updated_at,
user.id
).execute(&self.pool).await?;
Ok(())
}
async fn new_user(&self, email: &str, username: &str, passhash: &str) -> Result<i64, sqlx::Error> {
sqlx::query!(
"INSERT INTO users (email, username, passhash) VALUES ($1, $2, $3) RETURNING id",
email,
username,
passhash
)
.fetch_optional(&self.pool)
.await
.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
}
async fn get_by_username(&self, username: &str) -> Result<Option<User>, sqlx::Error> {
sqlx::query_as!(
User,
"SELECT id, email, username, nickname, passhash, totp_status as \"totp_status!: TotpStatus\", totp_secret, created_at, updated_at FROM users WHERE username = $1",
username
)
.fetch_optional(&self.pool)
.await
}
async fn delete_by_id(&self, id: i64) -> Result<(), sqlx::Error> {
sqlx::query!("DELETE FROM users WHERE id = $1", id)
.execute(&self.pool)
.await?;
Ok(())
}
async fn set_display_name(&self, id: i64, display_name: Option<String>) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET nickname = $1 WHERE id = $2",
display_name,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn set_username(&self, id: i64, username: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET username = $1 WHERE id = $2",
username,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn set_twofa_enabled(&self, id: i64, enabled: &TotpStatus) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET totp_status = $1 WHERE id = $2",
enabled as &TotpStatus,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn get_totp_secret(&self, id: i64) -> Result<Option<String>, sqlx::Error> {
sqlx::query!(
"SELECT totp_secret FROM users WHERE id = $1",
id
)
.fetch_optional(&self.pool)
.await
.map(|opt| opt.and_then(|row| row.totp_secret))
}
async fn set_totp_secret(&self, id: i64, secret: Option<String>) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET totp_secret = $1 WHERE id = $2",
secret,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn get_pass_hash(&self, id: i64) -> Result<String, sqlx::Error> {
sqlx::query!(
"SELECT passhash FROM users WHERE id = $1",
id
)
.fetch_optional(&self.pool)
.await
.and_then(|row| row.map(|r| r.passhash).ok_or(sqlx::Error::RowNotFound))
}
async fn set_pass_hash(&self, id: i64, pass_hash: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET passhash = $1 WHERE id = $2",
pass_hash,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn set_email(&self, id: i64, email: &str) -> Result<(), sqlx::Error> {
sqlx::query!(
"UPDATE users SET email = $1 WHERE id = $2",
email,
id
)
.execute(&self.pool).await?;
Ok(())
}
async fn set_role(&self, id: i64, role: &str) -> Result<(), sqlx::Error> {
sqlx::query(
"UPDATE users SET role = $1::user_role WHERE id = $2"
)
.bind(role)
.bind(id)
.execute(&self.pool).await?;
Ok(())
}
}
+47
View File
@@ -0,0 +1,47 @@
use std::sync::Arc;
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::error::{ApiResult, AppError};
use crate::model::auth::AccessToken;
use crate::repo::access_token_repo::AccessTokenRepo;
use crate::repo::AccessTokenRepoTrait;
#[derive(Clone)]
pub struct AccessTokenService {
repo: Arc<dyn AccessTokenRepoTrait>
}
impl AccessTokenService {
pub fn new(repo: Arc<dyn AccessTokenRepoTrait>) -> Self {
Self { repo }
}
pub async fn create(&self,
uid: i64, name: &str, max_uses: i32,
valid_from: DateTime<Utc>, valid_until: DateTime<Utc>
) -> ApiResult<String> {
if valid_from > valid_until {
return Err(AppError::bad_request("start date must be before end date"))
}
if valid_until < Utc::now() {
return Err(AppError::bad_request("expiry date must be after current date"))
}
let code = Uuid::new_v4().to_string();
self.repo.create_new(uid, name, &code, max_uses, valid_from, valid_until).await?;
Ok(code)
}
pub async fn get_valid_token(&self, token: &str) -> ApiResult<AccessToken> {
self.repo.get_code_not_expired(token).await?
.ok_or(AppError::unauthorised("invalid access token"))
}
pub async fn use_token(&self, id: i64) -> ApiResult<()> {
self.repo.use_token(id).await?;
Ok(())
}
}
+259
View File
@@ -0,0 +1,259 @@
use crate::api::auth::{Claims, TokenScope};
use crate::api::totp::totp_gen;
use crate::error::{ApiResult, AppError};
use crate::model::auth::AuthResponse;
use crate::repo::{UserRepo, AccessTokenRepoTrait};
use std::sync::Arc;
use argon2::password_hash::rand_core::OsRng;
use argon2::password_hash::SaltString;
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::api::totp::TotpStatus::{Disabled, Enabled};
use crate::svc::access_token_svc::AccessTokenService;
#[derive(Clone)]
pub struct AuthService {
users: Arc<dyn UserRepo>,
tokens: AccessTokenService,
}
impl AuthService {
pub fn new(users: Arc<dyn UserRepo>, tokens: AccessTokenService) -> Self {
Self { users, tokens }
}
pub async fn signup(&self,
email: &str, username: &str,
password: &str, access_token: &str
) -> ApiResult<AuthResponse> {
let tok_id = self.tokens.get_valid_token(access_token).await?.id;
let pass = password.to_string();
let svc = self.clone();
let hashed = tokio::task::spawn_blocking(move || svc.hash_password(&pass))
.await
.map_err(|_| AppError::internal("blocking task panicked"))??;
let uid = self.users
.new_user(email, username, &hashed).await?;
self.tokens.use_token(tok_id).await?;
let jwt = Claims::new(uid as usize, TokenScope::Full).encode();
Ok(AuthResponse {
token: jwt,
totp_required: false
})
}
pub async fn login(&self, username: &str, password: &str) -> ApiResult<AuthResponse> {
let user = self.users
.get_by_username(username).await?
.ok_or(AppError::unauthorised("invalid username"))?;
let pass = password.to_string();
let user_hash = user.passhash.clone();
let svc = self.clone();
tokio::task::spawn_blocking(move || svc.verify_password(&user_hash, &pass))
.await
.map_err(|_| AppError::internal("blocking task panicked"))??;
let scope = if user.totp_status == Enabled { TokenScope::TotpPending } else { TokenScope::Full };
let jwt = Claims::new(user.id as usize, scope).encode();
Ok(AuthResponse {
token: jwt,
totp_required: user.totp_status == Enabled
})
}
pub async fn login_totp(&self, uid: i64, code: &str) -> ApiResult<AuthResponse> {
let secret = self.users.get_totp_secret(uid).await?
.ok_or(AppError::unauthorised("2fa not enabled"))?;
self.verify_2fa(uid, &secret, code)?;
let jwt = Claims::new(uid as usize, TokenScope::Full).encode();
Ok(AuthResponse {
token: jwt,
totp_required: false
})
}
pub async fn disable_totp(&self, uid: i64, password: &str, totp_code: &str) -> ApiResult<AuthResponse> {
let mut user = self.users.get_by_id(uid).await
.ok_or(AppError::internal("user not found"))?;
let Some(secret) = user.totp_secret else {
return Err(AppError::bad_request("2fa not enabled"));
};
self.verify_password(&user.passhash, password)?;
self.verify_2fa(uid, &secret, totp_code)?;
user.totp_secret = None;
user.totp_status = Disabled;
self.users.save(&user).await?;
Ok(AuthResponse {
token: Claims::new(uid as usize, TokenScope::Full).encode(),
totp_required: false
})
}
pub async fn get_totp_status(&self, uid: i64) -> ApiResult<bool> {
Ok(
self.users.get_totp_secret(uid).await?.is_some()
)
}
pub async fn confirm_totp(&self, uid: i64, totp_code: &str) -> ApiResult<()> {
let secret = self.users.get_totp_secret(uid).await?
.ok_or(AppError::bad_request("2fa setup not initialised"))?;
self.verify_2fa(uid, &secret, totp_code)?;
self.users.set_twofa_enabled(uid, &Enabled).await?;
Ok(())
}
pub async fn get_or_create_totp_secret(
&self, uid: i64, password: &str,
) -> ApiResult<String> {
let user = self.users.get_by_id(uid).await
.ok_or(AppError::internal("user not found"))?;
let pass = password.to_string();
let user_hash = user.passhash.clone();
let svc = self.clone();
tokio::task::spawn_blocking(move || svc.verify_password(&user_hash, &pass))
.await
.map_err(|_| AppError::internal("blocking task panicked"))??;
if let Some(secret) = user.totp_secret {
return Ok(secret);
}
let new_secret = totp_rs::Secret::generate_secret()
.to_encoded()
.to_string();
self.users.set_totp_secret(uid, Some(new_secret.clone())).await?;
Ok(new_secret)
}
pub async fn verify_user_password(&self, uid: i64, password: &str) -> ApiResult<()> {
let hash = self.users.get_pass_hash(uid).await
.map_err(|_| AppError::internal("user not found"))?;
let pass = password.to_string();
let svc = self.clone();
tokio::task::spawn_blocking(move || svc.verify_password(&hash, &pass))
.await
.map_err(|_| AppError::internal("blocking task panicked"))??;
Ok(())
}
pub async fn verify_user_totp(&self, uid: i64, totp_code: &str) -> ApiResult<()> {
let secret = self.users.get_totp_secret(uid).await?
.ok_or(AppError::internal("user not found"))?;
self.verify_2fa(uid, &secret, totp_code)
}
pub fn hash_password(&self, password: &str) -> ApiResult<String> {
let salt = SaltString::generate(&mut OsRng);
Argon2::default()
.hash_password(password.as_bytes(), &salt)
.map_err(|_| AppError::internal("failed to hash password"))
.map(|hash| hash.to_string())
}
// Private helpers
fn verify_password(&self, pass_hash: &str, password: &str) -> ApiResult<()> {
let parsed_hash = PasswordHash::new(&pass_hash)
.map_err(|_| AppError::internal("invalid password hash"))?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|_| AppError::unauthorised("incorrect password"))?;
Ok(())
}
pub fn verify_2fa(&self, uid: i64, totp_secret: &str, totp_code: &str) -> ApiResult<()> {
if totp_gen(uid, totp_secret.as_bytes())
.map_err(|_| AppError::internal("invalid totp secret"))?
.check_current(totp_code)
.map_err(|_| AppError::internal("invalid totp code"))? {
Ok(())
} else {
Err(AppError::unauthorised("incorrect totp code"))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::repo::mock::{MockUserRepo, MockTokenRepo};
use std::sync::Mutex;
fn setup() -> AuthService {
unsafe {
std::env::set_var("JWT_SECRET", "test_secret");
}
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tok_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let tokens = AccessTokenService::new(tok_repo);
AuthService::new(users, tokens)
}
#[tokio::test]
async fn test_signup_and_login() {
let auth = setup();
let code = auth.tokens.create(1, "test", 1, Utc::now(), Utc::now()).await.unwrap();
let signup_res = auth.signup("test@example.com", "tester", "password123", &code).await;
assert!(signup_res.is_ok());
let login_res = auth.login("tester", "password123").await;
assert!(login_res.is_ok());
let login_data = login_res.unwrap();
assert!(!login_data.totp_required);
assert!(!login_data.token.is_empty());
}
#[tokio::test]
async fn test_login_invalid_password() {
let auth = setup();
let token_code = auth.tokens.create(1, "test", 1, Utc::now(), Utc::now()).await.unwrap();
auth.signup("test@example.com", "tester", "password123", &token_code).await.unwrap();
let login_res = auth.login("tester", "wrong_password").await;
assert!(login_res.is_err());
if let Err(AppError::Unauthorised(msg)) = login_res {
assert_eq!(msg, "incorrect password");
} else {
panic!("Expected Unauthorised error");
}
}
#[tokio::test]
async fn test_invite() {
let auth = setup();
let res = auth.tokens.create(1, "invite", 1, Utc::now(), Utc::now() + chrono::Duration::days(1)).await;
assert!(res.is_ok());
let code = res.unwrap();
assert!(!code.is_empty());
let token = auth.tokens.get_valid_token(&code).await;
assert!(token.is_ok());
}
}
+187
View File
@@ -0,0 +1,187 @@
use crate::api::chat::ChatMsg;
use crate::error::{ApiResult, AppError};
use crate::repo::message_repo::MessageRepository;
use crate::repo::{ChannelRepo, Repo, SpaceRepo, UserRepo};
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::broadcast::Sender;
use tokio::sync::{broadcast, Mutex};
use crate::model::space::SpaceDto;
use crate::svc::llm_service::LlmService;
/// ---------- shared broadcaster ----------
#[derive(Clone)]
pub struct ChatService {
users: Arc<dyn UserRepo>,
channels: Arc<dyn ChannelRepo>,
spaces: Arc<dyn SpaceRepo>,
messages: MessageRepository,
llm: LlmService,
buffer_size: usize,
senders: Arc<Mutex<HashMap<i64, Sender<ChatMsg>>>>,
}
impl ChatService {
pub fn new(
buffer_size: usize, llm: LlmService,
messages: MessageRepository, users: Arc<dyn UserRepo>,
channels: Arc<dyn ChannelRepo>, spaces: Arc<dyn SpaceRepo>,
) -> Self {
Self {
channels,
spaces,
llm,
users,
messages,
buffer_size,
senders: Arc::new(Mutex::new(std::collections::HashMap::new())),
}
}
pub async fn get_accessible_channels(&self, uid: i64) -> ApiResult<Vec<SpaceDto>> {
// let spaces = self.spaces.get_by_member(uid).await?;
// TODO! UNCOMMENT THIS ^^^^^^
let spaces = self.spaces.get_all().await?;
let mut result = Vec::new();
for space in spaces {
let channels = self.channels.get_by_space_id(space.id).await?;
result.push(SpaceDto {
channels,
id: space.id,
owner_id: space.owner_id,
name: space.name,
description: space.description,
created_at: space.created_at,
updated_at: space.updated_at,
});
}
Ok(result)
}
pub async fn get_messages(&self, channel_id: i64, limit: usize) -> ApiResult<Vec<ChatMsg>> {
let messages = self.messages.get_by_channel(channel_id, limit).await?;
Ok(messages)
}
/// Sends a chat message to the specified channel, persists it to the database,
/// and handles potential AI-generated replies asynchronously.
///
/// # Parameters
/// - `channel_id`: The ID of the channel to which the message will be sent.
/// - `uid`: The user ID of the sender.
/// - `text`: The content of the message to be sent.
/// - `created_at`: The timestamp at which the message was created.
///
/// # Returns
/// - `ApiResult<()>`: Indicates success or failure of the operation.
///
/// # Behavior
/// 1. Fetches the user by their `uid`. Returns an error if the user is not found.
/// 2. Constructs a `ChatMsg` object with the sender's `display_name` or `username`,
/// and the specified message content and timestamp.
/// 3. Publishes the constructed message to the given channel.
/// 4. Persists the message in the database.
/// 5. Spawns an asynchronous task to generate an LLM-powered (language model) reply:
/// - Sends the original message to the LLM worker for a potential reply.
/// - Publishes the LLM's reply to the same channel if successful.
/// - Persists the LLM's reply to the database.
///
/// # Notes
/// - Caching with Redis is planned for both message persistence and AI replies, but
/// is not implemented in the current version.
/// - The spawned asynchronous task does not block the main execution flow.
///
/// # Potential Errors
/// - Returns `AppError::NotFound` if the `uid` does not map to an existing user.
/// - Returns an error wrapped in `ApiResult` if the database operations fail.
///
/// # TODO
/// - Implement caching for both user-supplied messages and LLM-generated replies
/// using Redis at the repository or service layer.
pub async fn send(&self,
channel_id: i64, uid: i64,
text: &str, created_at: DateTime<Utc>
) -> ApiResult<()> {
let user = self.users.get_by_id(uid).await
.ok_or(AppError::NotFound)?;
let message = ChatMsg {
display_name: Some(user
.nickname.clone()
.unwrap_or_else(|| user.username.clone())),
user_id: uid,
text: text.to_string(),
timestamp: created_at,
};
self.publish(channel_id, message.clone()).await;
let _msg_id = self.messages.create_new(uid, channel_id, text, created_at).await?;
// TODO: caching w redis at repository layer
let svc_instance = self.clone();
let Some(text) = text.strip_prefix("/ask ") else {
return Ok(())
};
if !svc_instance.llm.enabled() {
return Ok(())
}
tokio::spawn(async move {
let response = svc_instance.llm
.query(&message)
.await;
if let Ok(reply) = response {
tracing::info!("LLM reply: {}", reply.text);
svc_instance.publish(channel_id, reply.clone()).await;
// TODO: cache response (or do with redis!)
if let Err(e) = svc_instance.messages
.create_new(reply.user_id, channel_id, &reply.text, reply.timestamp).await {
tracing::error!("Failed to persist LLM reply: {}", e);
}
tracing::info!("LLM reply persisted");
} else {
tracing::warn!("Error contacting LLM: {:?}", response);
}
});
Ok(())
}
/// Subscribe to the specified channel.
pub async fn subscribe(&self, channel_id: i64) -> broadcast::Receiver<ChatMsg> {
let mut map = self.senders.lock().await;
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
sender.subscribe()
}
// Private helper methods
/// Publish a message to the specified channel.
async fn publish(&self, channel_id: i64, msg: ChatMsg) {
let mut map = self.senders.lock().await;
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
let _ = sender.send(msg);
}
}
+89
View File
@@ -0,0 +1,89 @@
#[derive(Clone)]
pub struct LlmService;
static LMSTUDIO_URL: LazyLock<Option<String>> = LazyLock::new(|| env::var("LMSTUDIO_URL").ok());
static LMSTUDIO_MODEL: LazyLock<Option<String>> = LazyLock::new(|| env::var("LMSTUDIO_MODEL").ok());
impl LlmService {
pub fn new() -> Self {
Self {}
}
pub fn enabled(&self) -> bool {
LMSTUDIO_URL.is_some()
}
pub async fn query(&self, message: &ChatMsg) -> ApiResult<ChatMsg> {
let Some(url) = LMSTUDIO_URL.clone() else {
return Err(AppError::internal("AI not enabled!"))
};
let model = LMSTUDIO_MODEL.clone().unwrap_or_else(|| "gpt-oss-20b".into());
let client = reqwest::Client::new();
// Build the request body
let payload = LlmRequest {
model, // whatever model you run locally
messages: vec![Message {
role: "user".into(),
content: message.text.clone(),
}],
};
// POST to lmstudio (default 127.0.0.1:1234)
let resp = client
.post(url)
.json(&payload)
.send()
.await
.map_err(|_| AppError::internal("Failed to make request to LLM server"))?;
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)]
struct LlmResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize)]
struct Choice {
message: Message,
}
let llm_resp: LlmResponse = resp
.json()
.await
.map_err(|_| AppError::internal("Failed to parse LLM response"))?;
Ok(ChatMsg {
display_name: Some(String::from("llm")),
user_id: 0,
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now(),
})
}
}
use std::env;
use std::sync::LazyLock;
// src/llm.rs
use serde::{Deserialize, Serialize};
use crate::api::chat::ChatMsg;
use crate::error::{ApiResult, AppError};
use crate::svc::chat_svc::ChatService;
#[derive(Serialize)]
struct LlmRequest {
model: String,
messages: Vec<Message>,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String, // "user" or "assistant"
content: String,
}
+6
View File
@@ -0,0 +1,6 @@
pub mod auth_svc;
pub mod chat_svc;
pub mod settings_svc;
pub mod user_svc;
pub mod access_token_svc;
pub mod llm_service;
+116
View File
@@ -0,0 +1,116 @@
//! The `SettingsService` is responsible for managing user account settings, allowing users to
//! update their username, password, display name, email, and delete their account.
//! It interacts with the `AuthService` to handle authentication and password-related functionality
//! and the `UserRepository` to perform updates to user accounts in the data store.
use crate::error::{ApiResult, AppError};
use crate::repo::UserRepo;
use crate::svc::auth_svc::AuthService;
use std::sync::Arc;
#[derive(Clone)]
pub struct SettingsService {
auth: AuthService,
users: Arc<dyn UserRepo>,
}
impl SettingsService {
pub fn new(auth: AuthService, users: Arc<dyn UserRepo>) -> Self {
Self { auth, users }
}
pub async fn change_username(&self, uid: i64, new: &str) -> ApiResult<()> {
self.users.set_username(uid, new).await?;
Ok(())
}
pub async fn change_password(&self, uid: i64, old: &str, new: &str) -> ApiResult<()> {
self.auth.verify_user_password(uid, old).await?;
let hashed = self.auth.hash_password(new)?;
self.users.set_pass_hash(uid, &hashed).await?;
Ok(())
}
pub async fn change_display_name(&self, uid: i64, new: Option<String>) -> ApiResult<()> {
self.users.set_display_name(uid, new).await?;
Ok(())
}
pub async fn change_email(&self, uid: i64, new: &str) -> ApiResult<()> {
self.users.set_email(uid, new).await?;
Ok(())
}
pub async fn delete_account(&self, uid: i64, password: &str, totp_code: &Option<String>) -> ApiResult<()> {
self.auth.verify_user_password(uid, password).await?;
// check 2fa code is correct if enabled
if self.auth.get_totp_status(uid).await? {
let Some(totp_code) = totp_code else {
return Err(AppError::unauthorised("2fa code is required"))
};
self.auth.verify_user_totp(uid, totp_code).await?;
}
self.users.delete_by_id(uid).await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::repo::mock::{MockUserRepo, MockTokenRepo};
use std::sync::Mutex;
use chrono::Utc;
use crate::svc::access_token_svc::AccessTokenService;
fn setup() -> SettingsService {
unsafe {
std::env::set_var("JWT_SECRET", "test_secret");
}
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let token_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let tokens_svc = AccessTokenService::new(token_repo);
let auth = AuthService::new(users.clone(), tokens_svc.clone());
SettingsService::new(auth, users)
}
#[tokio::test]
async fn test_change_username() {
let settings = setup();
let uid = settings.users.new_user("test@example.com", "old", "pass").await.unwrap();
settings.change_username(uid, "new").await.unwrap();
let user = settings.users.get_by_id(uid).await.unwrap();
assert_eq!(user.username, "new");
}
#[tokio::test]
async fn test_change_password() {
let settings = setup();
let pass = "old_pass";
let hashed = settings.auth.hash_password(pass).unwrap();
let uid = settings.users.new_user("test@example.com", "user", &hashed).await.unwrap();
settings.change_password(uid, pass, "new_pass").await.unwrap();
let _user = settings.users.get_by_id(uid).await.unwrap();
assert!(settings.auth.verify_user_password(uid, "new_pass").await.is_ok());
}
#[tokio::test]
async fn test_delete_account() {
let settings = setup();
let pass = "password";
let hashed = settings.auth.hash_password(pass).unwrap();
let uid = settings.users.new_user("test@example.com", "user", &hashed).await.unwrap();
let res = settings.delete_account(uid, pass, &None).await;
assert!(res.is_ok());
let user = settings.users.get_by_id(uid).await;
assert!(user.is_none());
}
}
+28
View File
@@ -0,0 +1,28 @@
use crate::error::ApiResult;
use crate::repo::UserRepo;
use std::sync::Arc;
pub struct UserService {
repo: Arc<dyn UserRepo>
}
impl UserService {
pub fn new(repo: Arc<dyn UserRepo>) -> Self {
Self { repo }
}
pub async fn get_display_name(&self, uid: i64) -> ApiResult<String> {
// TODO: redis caching for display names
let user = self.repo.get_by_id(uid)
.await.ok_or(crate::error::AppError::NotFound)?;
Ok(user.nickname.unwrap_or_else(|| user.username))
}
pub async fn get_username(&self, uid: i64) -> ApiResult<String> {
self.repo.get_by_id(uid)
.await.ok_or(crate::error::AppError::NotFound)
.map(|u| u.username)
}
}
-188
View File
@@ -1,188 +0,0 @@
use argon2::{Argon2, PasswordHash, PasswordVerifier};
use redis::AsyncCommands;
use rocket::{http::Status, serde::json::Json, time::OffsetDateTime};
use rocket_db_pools::Connection;
use crate::{
auth::{Session, two_factor::totp_gen},
db::{Postgres, Redis},
};
pub struct User {
pub id: i32,
pub email: Option<String>,
pub username: String,
pub display_name: Option<String>,
pub pass_hash: String,
pub twofa_enabled: bool,
pub totp_secret: Option<String>,
pub created_at: Option<OffsetDateTime>,
pub updated_at: Option<OffsetDateTime>,
}
impl User {
pub async fn get_by_id(id: usize, db: &mut Connection<Postgres>) -> Option<Self> {
sqlx::query_as!(
Self,
"SELECT id, email, username, display_name, pass_hash, twofa_enabled, totp_secret, created_at, updated_at FROM users WHERE id = $1",
id as i32
)
.fetch_optional(&mut ***db)
.await
.unwrap_or(None)
}
pub async fn delete(&mut self, db: &mut Connection<Postgres>) -> Result<(), sqlx::Error> {
sqlx::query!("DELETE FROM users WHERE id = $1", self.id)
.execute(&mut ***db)
.await?;
Ok(())
}
pub fn verify_2fa(&self, code: &str) -> Result<(), Status> {
if totp_gen(
self.id as usize,
self.totp_secret
.clone()
.expect("user with 2fa enabled has no totp secret")
.as_bytes(),
)
.map_err(|_| Status::InternalServerError)?
.check_current(code)
.map_err(|_| Status::InternalServerError)?
{
Ok(())
} else {
Err(Status::Unauthorized)
}
}
pub fn verify_password(&self, password: &str) -> Result<(), Status> {
let parsed_hash = PasswordHash::new(&self.pass_hash)
.inspect_err(|e| {
tracing::error!("Failed to parse hash for password! uid:{} {e}", self.id)
})
.map_err(|_| Status::InternalServerError)?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|_| Status::Unauthorized)
}
pub async fn set_display_name(
&mut self,
display_name: Option<String>,
db: &mut Connection<Postgres>,
) -> Result<(), sqlx::Error> {
self.display_name = display_name;
sqlx::query!(
"UPDATE users SET display_name = $1 WHERE id = $2",
self.display_name,
self.id
)
.execute(&mut ***db)
.await?;
Ok(())
}
pub async fn set_username(
&mut self,
username: String,
db: &mut Connection<Postgres>,
) -> Result<(), sqlx::Error> {
self.username = username;
sqlx::query!(
"UPDATE users SET username = $1 WHERE id = $2",
self.username,
self.id
)
.execute(&mut ***db)
.await?;
Ok(())
}
pub async fn set_twofa_enabled(
&mut self,
enabled: bool,
db: &mut Connection<Postgres>,
) -> Result<(), sqlx::Error> {
self.twofa_enabled = enabled;
sqlx::query!(
"UPDATE users SET twofa_enabled = $1 WHERE id = $2",
self.twofa_enabled,
self.id
)
.execute(&mut ***db)
.await?;
Ok(())
}
pub async fn set_pass_hash(
&mut self,
pass_hash: String,
db: &mut Connection<Postgres>,
) -> Result<(), sqlx::Error> {
self.pass_hash = pass_hash;
sqlx::query!(
"UPDATE users SET pass_hash = $1 WHERE id = $2",
self.pass_hash,
self.id
)
.execute(&mut ***db)
.await?;
Ok(())
}
}
#[get("/users", rank = 2)]
pub async fn users(_ag: Session, mut db: Connection<Postgres>) -> Json<Vec<i32>> {
sqlx::query!("SELECT id FROM users")
.fetch_all(&mut **db)
.await
.unwrap_or_else(|_| Vec::new())
.into_iter()
.map(|row| row.id)
.collect::<Vec<i32>>()
.into()
}
#[get("/users/<id>", rank = 1)]
pub async fn display_name(
id: usize,
_ag: Session,
mut pgsql_conn: Connection<Postgres>,
mut redis_conn: Connection<Redis>,
) -> String {
UserCache::username(id, &mut redis_conn, &mut pgsql_conn).await
}
pub struct UserCache {}
impl UserCache {
pub async fn username(
id: usize,
redis_conn: &mut Connection<Redis>,
pgsql_conn: &mut Connection<Postgres>,
) -> String {
if let Ok(val) = redis_conn.get(format!("users:{id}")).await {
return val;
}
if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
.fetch_one(&mut ***pgsql_conn)
.await
{
let username = v.username;
Self::insert(id, &username, redis_conn).await;
username
} else {
unimplemented!()
}
}
pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800)
.await
.expect("failed to insert key");
}
}
+198
View File
@@ -0,0 +1,198 @@
use backend::rocket_builder;
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
use backend::repo::message_repo::MessageRepository;
use backend::svc::chat_svc::ChatService;
use backend::repo::user_repo::UserRepository;
use backend::repo::{Repo, AccessTokenRepoTrait};
use rocket::local::asynchronous::Client;
use rocket::http::{Status, ContentType};
use serde_json::json;
use std::sync::{Arc, Mutex};
use sqlx::PgPool;
use chrono::Utc;
use backend::svc::llm_service::LlmService;
async fn test_rocket() -> rocket::Rocket<rocket::Build> {
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = MessageRepository::new(pool.clone());
let user_repo = Arc::new(UserRepository::new(pool));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
rocket_builder(users, tokens, chat_service)
}
#[rocket::async_test]
async fn test_unauthorized_access() {
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
// Attempt to access a protected endpoint without authentication
let response = client.patch("/api/settings/display_name").dispatch().await;
assert_eq!(response.status(), Status::Unauthorized);
let response = client.post("/api/settings/password").dispatch().await;
assert_eq!(response.status(), Status::Unauthorized);
let response = client.delete("/api/settings").dispatch().await;
assert_eq!(response.status(), Status::Unauthorized);
}
#[rocket::async_test]
async fn test_signup_invalid_token() {
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
let signup_data = json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": "invalid-token"
});
let response = client.post("/api/signup")
.header(ContentType::JSON)
.body(signup_data.to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Unauthorized);
}
#[rocket::async_test]
async fn test_login_invalid_credentials() {
let client = Client::tracked(test_rocket().await).await.expect("valid rocket instance");
let login_data = json!({
"username": "nonexistent",
"password": "wrongpassword"
});
let response = client.post("/api/login")
.header(ContentType::JSON)
.body(login_data.to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Unauthorized);
}
#[rocket::async_test]
async fn test_full_auth_flow() {
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = MessageRepository::new(pool.clone());
let user_repo = Arc::new(UserRepository::new(pool));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
let token_code = "valid-token";
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
let client = Client::tracked(rocket_builder(users, tokens, chat_service)).await.expect("valid rocket instance");
// 1. Signup
let signup_data = json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": token_code
});
let response = client.post("/api/signup")
.header(ContentType::JSON)
.body(signup_data.to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let body = response.into_string().await.unwrap();
assert!(body.contains("token"));
// 2. Login
let login_data = json!({
"username": "testuser",
"password": "password123"
});
let response = client.post("/api/login")
.header(ContentType::JSON)
.body(login_data.to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let body = response.into_string().await.unwrap();
assert!(body.contains("token"));
}
#[rocket::async_test]
async fn test_delete_account_security() {
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = MessageRepository::new(pool.clone());
let user_repo = Arc::new(UserRepository::new(pool));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
let token_code = "valid-token";
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
client.post("/api/signup")
.header(ContentType::JSON)
.body(json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": token_code
}).to_string())
.dispatch()
.await;
// Login to get JWT
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser",
"password": "password123"
}).to_string())
.dispatch()
.await;
let auth_resp: serde_json::Value = serde_json::from_str(&login_res.into_string().await.unwrap()).unwrap();
let jwt = auth_resp["token"].as_str().unwrap();
// 1. Delete with WRONG password
let response = client.delete("/api/settings")
.header(ContentType::JSON)
.header(rocket::http::Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"password": "wrongpassword",
"totp_code": null
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Unauthorized);
// 2. Delete with CORRECT password
let response = client.delete("/api/settings")
.header(ContentType::JSON)
.header(rocket::http::Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"password": "password123",
"totp_code": null
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
// Verify user is gone
assert!(users.users.lock().unwrap().is_empty());
}
+142
View File
@@ -0,0 +1,142 @@
use backend::rocket_builder;
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
use backend::repo::message_repo::MessageRepository;
use backend::svc::chat_svc::ChatService;
use backend::repo::{Repo, AccessTokenRepoTrait};
use rocket::local::asynchronous::Client;
use rocket::http::{Status, ContentType, Header};
use serde_json::{json, Value};
use std::sync::{Arc, Mutex};
use sqlx::PgPool;
use chrono::Utc;
use backend::svc::llm_service::LlmService;
async fn setup_client_with_svc(chat_service: ChatService, users: Arc<MockUserRepo>, tokens: Arc<MockTokenRepo>) -> (Client, String) {
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
// Create a user and get JWT
let token_code = "valid-token";
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
let jwt = {
let signup_res = client.post("/api/signup")
.header(ContentType::JSON)
.body(json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": token_code
}).to_string())
.dispatch()
.await;
assert_eq!(signup_res.status(), Status::Ok);
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser",
"password": "password123"
}).to_string())
.dispatch()
.await;
assert_eq!(login_res.status(), Status::Ok, "Login failed");
let body = login_res.into_string().await.expect("login body");
let auth_resp: serde_json::Value = serde_json::from_str(&body).unwrap();
auth_resp["token"].as_str().unwrap().to_string()
};
(client, jwt)
}
#[rocket::async_test]
async fn test_chat_event_stream_consistency() {
unsafe { std::env::set_var("JWT_SECRET", "test_secret"); }
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = <MessageRepository as Repo>::new(pool.clone());
let users_repo = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens_repo = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let llm_service = LlmService::new();
let chat_service = ChatService::new(1024, messages, users_repo.clone(), llm_service);
let (client, jwt) = setup_client_with_svc(chat_service.clone(), users_repo.clone(), tokens_repo.clone()).await;
// Use the same client for sender but with a different user (or the same, doesn't matter for broadcast)
// Actually, to simulate another user, we should sign up another user.
let jwt_sender = {
let token_code = "valid-token-2";
tokens_repo.create_new(1, "test2", token_code, 1, Utc::now(), Utc::now() + chrono::Duration::days(1)).await.unwrap();
let signup_res = client.post("/api/signup")
.header(ContentType::JSON)
.body(json!({
"email": "test2@example.com",
"username": "testuser2",
"password": "password123",
"access_token": token_code
}).to_string())
.dispatch()
.await;
assert_eq!(signup_res.status(), Status::Ok);
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser2",
"password": "password123"
}).to_string())
.dispatch()
.await;
let body = login_res.into_string().await.unwrap();
let auth_resp: serde_json::Value = serde_json::from_str(&body).unwrap();
auth_resp["token"].as_str().unwrap().to_string()
};
let channel_id = 1;
// Start listening to the event stream
let mut response = client.get(format!("/api/events/{}", channel_id))
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let num_messages = 5; // Reduced for faster debugging
let mut received_count = 0;
let jwt_clone = jwt.clone();
tokio::spawn(async move {
for i in 0..num_messages {
let msg = format!("Message {}", i);
let res = sender_client.post(format!("/api/chat/{}", channel_id))
.header(ContentType::JSON)
.header(Header::new("Authorization", format!("Bearer {}", jwt_clone)))
.body(json!({
"display_name": "testuser",
"user_id": 1,
"text": msg,
"timestamp": Utc::now()
}).to_string())
.dispatch()
.await;
assert_eq!(res.status(), Status::Ok);
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
});
// Wait a bit for messages to be posted
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Consume the stream
let text = response.into_string().await.unwrap();
println!("Received chunk: {}", text);
let mut received_count = 0;
for line in text.lines() {
if line.starts_with("data:") {
received_count += 1;
}
}
assert_eq!(received_count, num_messages, "Should receive all posted messages. Received: {}. Full text: {}", received_count, text);
}
+121
View File
@@ -0,0 +1,121 @@
use backend::rocket_builder;
use backend::repo::mock::{MockUserRepo, MockTokenRepo};
use backend::repo::message_repo::MessageRepository;
use backend::svc::chat_svc::ChatService;
use backend::repo::user_repo::UserRepository;
use backend::repo::{Repo, AccessTokenRepoTrait};
use rocket::local::asynchronous::Client;
use rocket::http::{Status, ContentType, Header};
use serde_json::json;
use std::sync::{Arc, Mutex};
use sqlx::PgPool;
use chrono::Utc;
use backend::svc::llm_service::LlmService;
async fn setup_client() -> (Client, Arc<MockUserRepo>, String) {
let users = Arc::new(MockUserRepo { users: Mutex::new(vec![]) });
let tokens = Arc::new(MockTokenRepo { tokens: Mutex::new(vec![]) });
let pool = PgPool::connect_lazy("postgres://localhost/unused").unwrap();
let messages = MessageRepository::new(pool.clone());
let user_repo = Arc::new(UserRepository::new(pool));
let llm_service = LlmService::new();
let chat_service = ChatService::new(32, messages, user_repo, llm_service);
let client = Client::tracked(rocket_builder(users.clone(), tokens.clone(), chat_service)).await.expect("valid rocket instance");
// Create a user and get JWT
let token_code = "valid-token";
tokens.create_new(1, "test", token_code, 1, Utc::now(), Utc::now()).await.unwrap();
client.post("/api/signup")
.header(ContentType::JSON)
.body(json!({
"email": "test@example.com",
"username": "testuser",
"password": "password123",
"access_token": token_code
}).to_string())
.dispatch()
.await;
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser",
"password": "password123"
}).to_string())
.dispatch()
.await;
let auth_resp: serde_json::Value = serde_json::from_str(&login_res.into_string().await.unwrap()).unwrap();
let jwt = auth_resp["token"].as_str().unwrap().to_string();
(client, users, jwt)
}
#[rocket::async_test]
async fn test_change_display_name() {
let (client, users, jwt) = setup_client().await;
let response = client.patch("/api/settings/display_name")
.header(ContentType::JSON)
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"display_name": "New Display Name"
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let user = users.users.lock().unwrap()[0].clone();
assert_eq!(user.nickname, Some("New Display Name".to_string()));
}
#[rocket::async_test]
async fn test_change_username() {
let (client, users, jwt) = setup_client().await;
let response = client.patch("/api/settings/username")
.header(ContentType::JSON)
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"username": "newusername"
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
let user = users.users.lock().unwrap()[0].clone();
assert_eq!(user.username, "newusername");
}
#[rocket::async_test]
async fn test_change_password() {
let (client, _, jwt) = setup_client().await;
let response = client.post("/api/settings/password")
.header(ContentType::JSON)
.header(Header::new("Authorization", format!("Bearer {}", jwt)))
.body(json!({
"old_password": "password123",
"new_password": "newpassword456"
}).to_string())
.dispatch()
.await;
assert_eq!(response.status(), Status::Ok);
// Verify login with new password
let login_res = client.post("/api/login")
.header(ContentType::JSON)
.body(json!({
"username": "testuser",
"password": "newpassword456"
}).to_string())
.dispatch()
.await;
assert_eq!(login_res.status(), Status::Ok);
}