11 Commits

Author SHA1 Message Date
zxq5 7e001d8769 idk 2026-06-03 19:12:23 +01:00
zxq5 2f34976f3e Merge remote-tracking branch 'origin/dev' into dev 2026-04-11 00:10:09 +01:00
zxq5 d1208f7e39 frontend v0.4.1-2
- added invite section to UI and some general bug fixes
2026-04-11 00:09:47 +01:00
zxq5 d6ba875297 addedd RELEASE_MODE=1 to run var to prevent crash in absence of .env
file
2026-04-08 00:05:54 +01:00
zxq5 529d09aabc frontend v0.4.1
- fixed most of the bugs with the rewrite. should be ready to deploy now
2026-04-08 00:00:28 +01:00
zxq5 5291e7dee6 rewritten docker compose files and updated giutignore 2026-04-06 15:42:20 +01:00
zxq5 3c52ade946 deleted some old files 2026-04-06 15:38:28 +01:00
zxq5 0f692e4372 updated docker compose and formatted backend. 2026-04-06 13:44:50 +01:00
zxq5 d33eee1281 frontend v0.4.0 2026-04-06 01:02:39 +01:00
zxq5 7c9b733813 updated gitignore 2026-04-06 01:00:27 +01:00
zxq5 bda1ef251a full backend rewrite.
calling this v0.4.0
2026-04-06 00:57:23 +01:00
799 changed files with 4768 additions and 94668 deletions
+2 -1
View File
@@ -1,6 +1,7 @@
*/target */target
.env .env
.log* .log*
Cargo.lock Cargo.lock
.cargo/ .cargo/
docker-compose* .sqlx/
+4 -1
View File
@@ -2,7 +2,10 @@
<module type="JAVA_MODULE" version="4"> <module type="JAVA_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true"> <component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output /> <exclude-output />
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/backend/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/backend/target" />
</content>
<orderEntry type="inheritedJdk" /> <orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
+17
View File
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="chatapp dev" uuid="81992477-fd6f-427e-a27e-7378c26db6ef">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>
<jdbc-url>jdbc:postgresql://100.118.108.58:5432/chatapp_dev</jdbc-url>
<jdbc-additional-properties>
<property name="com.intellij.clouds.kubernetes.db.host.port" />
<property name="com.intellij.clouds.kubernetes.db.enabled" value="false" />
<property name="com.intellij.clouds.kubernetes.db.container.port" />
</jdbc-additional-properties>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>
+11
View File
@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GradleMigrationSettings" migrationVersion="1" />
<component name="GradleSettings">
<option name="linkedExternalProjectsSettings">
<GradleProjectSettings>
<option name="externalProjectPath" value="$PROJECT_DIR$/android" />
</GradleProjectSettings>
</option>
</component>
</project>
+1
View File
@@ -1,4 +1,5 @@
<project version="4"> <project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="ProjectRootManager" version="2" project-jdk-name="25" project-jdk-type="JavaSDK"> <component name="ProjectRootManager" version="2" project-jdk-name="25" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" /> <output url="file://$PROJECT_DIR$/out" />
</component> </component>
+3
View File
@@ -1,7 +1,9 @@
*.iml *.iml
.gradle .gradle
/local.properties /local.properties
/keystore.properties
/.idea/caches /.idea/caches
/.idea/.cache
/.idea/libraries /.idea/libraries
/.idea/modules.xml /.idea/modules.xml
/.idea/workspace.xml /.idea/workspace.xml
@@ -13,3 +15,4 @@
.externalNativeBuild .externalNativeBuild
.cxx .cxx
local.properties local.properties
release/
+8
View File
@@ -4,6 +4,14 @@
<selectionStates> <selectionStates>
<SelectionState runConfigName="app"> <SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" /> <option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-04-02T14:33:39.814557661Z">
<Target type="DEFAULT_BOOT">
<handle>
<DeviceId pluginId="PhysicalDevice" identifier="serial=00319362N000094" />
</handle>
</Target>
</DropdownSelection>
<DialogSelection />
</SelectionState> </SelectionState>
<SelectionState runConfigName="MainActivity"> <SelectionState runConfigName="MainActivity">
<option name="selectionMode" value="DROPDOWN" /> <option name="selectionMode" value="DROPDOWN" />
+34 -3
View File
@@ -1,3 +1,5 @@
import java.util.Properties
plugins { plugins {
alias(libs.plugins.android.application) alias(libs.plugins.android.application)
alias(libs.plugins.kotlin.compose) alias(libs.plugins.kotlin.compose)
@@ -8,6 +10,25 @@ android {
namespace = "dev.zxq5.chatapp.android" namespace = "dev.zxq5.chatapp.android"
compileSdk = 35 compileSdk = 35
val keystorePropertiesFile = rootProject.file("local.properties")
val keystoreProperties = Properties()
if (keystorePropertiesFile.exists()) {
keystoreProperties.load(keystorePropertiesFile.inputStream())
}
signingConfigs {
create("release") {
storeFile = file("${System.getProperty("user.home")}/keystores/chatapp.jks")
storePassword = keystoreProperties["KEYSTORE_PASSWORD"] as String?
?: System.getenv("KEYSTORE_PASSWORD")
?: ""
keyAlias = "chatapp"
keyPassword = keystoreProperties["KEY_PASSWORD"] as String?
?: System.getenv("KEY_PASSWORD")
?: ""
}
}
defaultConfig { defaultConfig {
applicationId = "dev.zxq5.chatapp.android" applicationId = "dev.zxq5.chatapp.android"
minSdk = 26 minSdk = 26
@@ -20,19 +41,30 @@ android {
buildTypes { buildTypes {
release { release {
isMinifyEnabled = false isMinifyEnabled = true // shrinks code
isShrinkResources = true // removes unused resources
signingConfig = signingConfigs.getByName("release")
proguardFiles( proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"), getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro" "proguard-rules.pro"
) )
buildConfigField("String", "BASE_URL", "\"https://chat.zxq5.dev\"")
}
debug {
isMinifyEnabled = false
isDebuggable = true
applicationIdSuffix = ".debug" // lets you install both side by side
buildConfigField("String", "BASE_URL", "\"http://zxq5-x1:8000\"")
} }
} }
compileOptions { compileOptions {
sourceCompatibility = JavaVersion.VERSION_11 sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11 targetCompatibility = JavaVersion.VERSION_11
} }
buildFeatures { buildFeatures {
compose = true compose = true
buildConfig = true
} }
} }
@@ -44,7 +76,6 @@ dependencies {
implementation(libs.ktor.client.auth) // Auth plugin implementation(libs.ktor.client.auth) // Auth plugin
// Kotlinx Serialization // Kotlinx Serialization
implementation(libs.kotlinx.serialization.json) implementation(libs.kotlinx.serialization.json)
// Coroutines // Coroutines
implementation(libs.kotlinx.coroutines.android) implementation(libs.kotlinx.coroutines.android)
@@ -73,4 +104,4 @@ dependencies {
androidTestImplementation(libs.androidx.compose.ui.test.junit4) androidTestImplementation(libs.androidx.compose.ui.test.junit4)
debugImplementation(libs.androidx.compose.ui.tooling) debugImplementation(libs.androidx.compose.ui.tooling)
debugImplementation(libs.androidx.compose.ui.test.manifest) debugImplementation(libs.androidx.compose.ui.test.manifest)
} }
+27 -1
View File
@@ -18,4 +18,30 @@
# If you keep the line number information, uncomment this to # If you keep the line number information, uncomment this to
# hide the original source file name. # hide the original source file name.
#-renamesourcefileattribute SourceFile #-renamesourcefileattribute SourceFile
# Ktor
-keep class io.ktor.** { *; }
-keep class kotlinx.coroutines.** { *; }
# Kotlinx serialization
-keepattributes *Annotation*, InnerClasses
-dontnote kotlinx.serialization.AnnotationsKt
-keep,includedescriptorclasses class dev.zxq5.chatapp.android.**$$serializer { *; }
-keepclassmembers class dev.zxq5.chatapp.android.** {
*** Companion;
}
-keepclasseswithmembers class dev.zxq5.chatapp.android.** {
kotlinx.serialization.KSerializer serializer(...);
}
# Keep model classes (serialization needs these)
-keep class dev.zxq5.chatapp.android.api.model.** { *; }
-keep class dev.zxq5.chatapp.android.data.model.** { *; }
# Fix for missing errorprone and javax annotations used by Tink and other libraries
-dontwarn com.google.errorprone.annotations.**
-dontwarn javax.annotation.**
# Fix for missing java.lang.management referenced by Ktor (not available on Android)
-dontwarn java.lang.management.**
+7 -1
View File
@@ -1,8 +1,9 @@
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android" <manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"> xmlns:tools="http://tools.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" /> <uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.POST_NOTIFICATIONS"/>
<application <application
android:name=".ChatApplication" android:name=".ChatApplication"
@@ -15,6 +16,11 @@
android:supportsRtl="true" android:supportsRtl="true"
android:theme="@style/Theme.Chatapp" android:theme="@style/Theme.Chatapp"
android:usesCleartextTraffic="true"> android:usesCleartextTraffic="true">
<service
android:name=".core.service.MessageStreamService"
android:exported="false"/>
<activity <activity
android:name=".MainActivity" android:name=".MainActivity"
android:exported="true" android:exported="true"
@@ -1,6 +1,9 @@
package dev.zxq5.chatapp.android package dev.zxq5.chatapp.android
import android.app.Application import android.app.Application
import android.app.NotificationChannel
import android.app.NotificationManager
import android.os.Build
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.data.repository.AuthRepository import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.data.repository.ChatRepository
@@ -8,6 +11,10 @@ import dev.zxq5.chatapp.android.data.repository.SettingsRepository
class ChatApplication : Application() { class ChatApplication : Application() {
object AppState {
var isInForeground = false
}
val tokenStore by lazy { TokenStore(this) } val tokenStore by lazy { TokenStore(this) }
val authRepository by lazy { AuthRepository(tokenStore) } val authRepository by lazy { AuthRepository(tokenStore) }
val chatRepository by lazy { ChatRepository(tokenStore) } val chatRepository by lazy { ChatRepository(tokenStore) }
@@ -15,5 +22,30 @@ class ChatApplication : Application() {
override fun onCreate() { override fun onCreate() {
super.onCreate() super.onCreate()
createNotificationChannels()
}
private fun createNotificationChannels() {
val messageChannel = NotificationChannel(
"messages",
"Messages",
NotificationManager.IMPORTANCE_HIGH
).apply {
description = "New message notifications"
enableVibration(true)
}
// add this — required for the foreground service persistent notification
val serviceChannel = NotificationChannel(
"service",
"Background connection",
NotificationManager.IMPORTANCE_LOW
).apply {
description = "Keeps messages running in background"
}
val manager = getSystemService(NotificationManager::class.java)
manager.createNotificationChannel(messageChannel)
manager.createNotificationChannel(serviceChannel)
} }
} }
@@ -1,28 +1,49 @@
package dev.zxq5.chatapp.android package dev.zxq5.chatapp.android
import android.Manifest
import android.os.Build
import android.os.Bundle import android.os.Bundle
import androidx.activity.ComponentActivity import androidx.activity.ComponentActivity
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.compose.setContent import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.outlined.ChatBubbleOutline
import androidx.compose.material.icons.outlined.PeopleOutline
import androidx.compose.material.icons.outlined.Settings
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.NavigationBar
import androidx.compose.material3.NavigationBarItem
import androidx.compose.material3.NavigationBarItemDefaults
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp
import androidx.lifecycle.viewmodel.compose.viewModel import androidx.lifecycle.viewmodel.compose.viewModel
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.ChatApplication.AppState
import dev.zxq5.chatapp.android.data.repository.AuthRepository import dev.zxq5.chatapp.android.core.service.MessageStreamService
import dev.zxq5.chatapp.android.data.repository.AuthState import dev.zxq5.chatapp.android.data.repository.AuthState
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.feature.auth.AuthScreen
import dev.zxq5.chatapp.android.data.repository.SettingsRepository
import dev.zxq5.chatapp.android.feature.auth.AuthViewModel import dev.zxq5.chatapp.android.feature.auth.AuthViewModel
import dev.zxq5.chatapp.android.feature.chat.ChatScreen
import dev.zxq5.chatapp.android.feature.chat.ChatViewModel import dev.zxq5.chatapp.android.feature.chat.ChatViewModel
import dev.zxq5.chatapp.android.feature.chat.Screen import dev.zxq5.chatapp.android.feature.chat.Screen
import dev.zxq5.chatapp.android.feature.settings.SettingsViewModel import dev.zxq5.chatapp.android.feature.contacts.ContactsScreen
import dev.zxq5.chatapp.android.feature.auth.AuthScreen
import dev.zxq5.chatapp.android.feature.chat.ChatScreen
import dev.zxq5.chatapp.android.feature.settings.SettingsScreen import dev.zxq5.chatapp.android.feature.settings.SettingsScreen
import dev.zxq5.chatapp.android.feature.settings.SettingsViewModel
import dev.zxq5.chatapp.android.ui.theme.ChatappTheme import dev.zxq5.chatapp.android.ui.theme.ChatappTheme
class MainActivity : ComponentActivity() { class MainActivity : ComponentActivity() {
@@ -30,7 +51,6 @@ class MainActivity : ComponentActivity() {
super.onCreate(savedInstanceState) super.onCreate(savedInstanceState)
val app = application as ChatApplication val app = application as ChatApplication
val tokenStore = app.tokenStore
val authRepository = app.authRepository val authRepository = app.authRepository
val chatRepository = app.chatRepository val chatRepository = app.chatRepository
val settingsRepository = app.settingsRepository val settingsRepository = app.settingsRepository
@@ -44,37 +64,149 @@ class MainActivity : ComponentActivity() {
val authState by authViewModel.authState.collectAsState() val authState by authViewModel.authState.collectAsState()
val currentScreen by chatViewModel.currentScreen.collectAsState() val currentScreen by chatViewModel.currentScreen.collectAsState()
val selectedChannelId by chatViewModel.channelId.collectAsState()
Scaffold(modifier = Modifier.fillMaxSize()) { innerPadding -> // Permission request launcher
androidx.compose.foundation.layout.Box(modifier = Modifier.padding(innerPadding)) { val launcher = rememberLauncherForActivityResult(
when (authState) { contract = ActivityResultContracts.RequestPermission(),
AuthState.Authenticated -> { onResult = { isGranted ->
when (currentScreen) { if (isGranted && authState == AuthState.Authenticated) {
Screen.CHAT -> ChatScreen( MessageStreamService.start(this@MainActivity)
viewModel = chatViewModel, }
onNavigateToSettings = { chatViewModel.navigateTo(Screen.SETTINGS) }, }
onLogout = { )
authViewModel.logout()
chatViewModel.clearChat() LaunchedEffect(authState) {
} when (authState) {
) AuthState.Authenticated -> {
Screen.SETTINGS -> SettingsScreen( if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
viewModel = settingsViewModel, launcher.launch(Manifest.permission.POST_NOTIFICATIONS)
onBack = { chatViewModel.navigateTo(Screen.CHAT) },
onLogout = {
authViewModel.logout()
chatViewModel.clearChat()
}
)
}
} }
AuthState.AwaitingTotp, AuthState.Unauthenticated -> { MessageStreamService.start(this@MainActivity)
AuthScreen(viewModel = authViewModel) chatViewModel.loadAccessibleChannels()
}
AuthState.Unauthenticated -> MessageStreamService.stop(this@MainActivity)
AuthState.AwaitingTotp -> {}
}
}
LaunchedEffect(Unit) {
chatViewModel.onUnauthorized = {
authViewModel.logout()
chatViewModel.clearChat()
}
}
LaunchedEffect(Unit) {
intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let {
chatViewModel.switchChannel(it.toLong())
}
}
if (authState == AuthState.Authenticated) {
Scaffold(
modifier = Modifier.fillMaxSize(),
bottomBar = {
// Only show bottom bar if we are NOT inside a specific chat channel
if (selectedChannelId == null) {
BottomDock(
currentScreen = currentScreen,
onNavigate = { chatViewModel.navigateTo(it) }
)
}
}
) { innerPadding ->
Box(modifier = Modifier.padding(innerPadding)) {
when (currentScreen) {
Screen.CHAT -> ChatScreen(
viewModel = chatViewModel,
onNavigateToSettings = { chatViewModel.navigateTo(Screen.SETTINGS) },
onLogout = {
authViewModel.logout()
chatViewModel.clearChat()
}
)
Screen.CONTACTS -> ContactsScreen()
Screen.SETTINGS -> SettingsScreen(
viewModel = settingsViewModel,
onLogout = {
authViewModel.logout()
chatViewModel.clearChat()
}
)
} }
} }
} }
} else {
AuthScreen(viewModel = authViewModel)
} }
} }
} }
} }
override fun onResume() {
super.onResume()
AppState.isInForeground = true
}
override fun onPause() {
super.onPause()
AppState.isInForeground = false
}
override fun onNewIntent(intent: android.content.Intent) {
super.onNewIntent(intent)
intent.getIntExtra("channel_id", -1).takeIf { it != -1 }?.let { channelId ->
MessageStreamService.instance?.activeChannelId = channelId.toLong()
}
}
}
@Composable
fun BottomDock(currentScreen: Screen, onNavigate: (Screen) -> Unit) {
NavigationBar(
containerColor = MaterialTheme.colorScheme.background,
tonalElevation = 0.dp,
modifier = Modifier
.height(80.dp)
.border(
0.5.dp,
MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f),
RoundedCornerShape(topStart = 0.dp, topEnd = 0.dp)
)
) {
NavigationBarItem(
selected = currentScreen == Screen.CHAT,
onClick = { onNavigate(Screen.CHAT) },
icon = { Icon(Icons.Outlined.ChatBubbleOutline, contentDescription = "Chat") },
label = { Text("chat", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.CONTACTS,
onClick = { onNavigate(Screen.CONTACTS) },
icon = { Icon(Icons.Outlined.PeopleOutline, contentDescription = "Contacts") },
label = { Text("contacts", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.SETTINGS,
onClick = { onNavigate(Screen.SETTINGS) },
icon = { Icon(Icons.Outlined.Settings, contentDescription = "Settings") },
label = { Text("settings", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
}
} }
@@ -1,10 +1,11 @@
package dev.zxq5.chatapp.android.api package dev.zxq5.chatapp.android.api
import android.util.Log import android.util.Log
import dev.zxq5.chatapp.android.BuildConfig
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.LoginRequest import dev.zxq5.chatapp.android.api.model.LoginRequest
import dev.zxq5.chatapp.android.api.model.LoginResponse import dev.zxq5.chatapp.android.api.model.LoginResponse
import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode
import dev.zxq5.chatapp.android.core.BASE_URL
import dev.zxq5.chatapp.android.core.error.ApiResult import dev.zxq5.chatapp.android.core.error.ApiResult
import dev.zxq5.chatapp.android.api.model.SignupRequest import dev.zxq5.chatapp.android.api.model.SignupRequest
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
@@ -1,14 +1,19 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.api package dev.zxq5.chatapp.android.api
import dev.zxq5.chatapp.android.core.BASE_URL import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.api.model.SendMessage import dev.zxq5.chatapp.android.api.model.SendMessage
import dev.zxq5.chatapp.android.api.model.SpaceDto
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.android.Android import io.ktor.client.engine.android.Android
import io.ktor.client.plugins.auth.Auth import io.ktor.client.plugins.auth.Auth
import io.ktor.client.plugins.auth.providers.BearerTokens import io.ktor.client.plugins.auth.providers.BearerTokens
import io.ktor.client.plugins.auth.providers.bearer import io.ktor.client.plugins.auth.providers.bearer
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.get
import io.ktor.client.request.post import io.ktor.client.request.post
import io.ktor.client.request.prepareGet import io.ktor.client.request.prepareGet
import io.ktor.client.request.setBody import io.ktor.client.request.setBody
@@ -16,12 +21,17 @@ import io.ktor.client.statement.bodyAsChannel
import io.ktor.http.ContentType import io.ktor.http.ContentType
import io.ktor.http.contentType import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json import io.ktor.serialization.kotlinx.json.json
import io.ktor.utils.io.readUTF8Line import io.ktor.utils.io.readLine
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flow
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import kotlin.time.Clock
import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
class ChatClient(private val token: String) { class ChatClient(private val token: String) {
private val http = HttpClient(Android) { private val http = HttpClient(Android) {
install(ContentNegotiation) { install(ContentNegotiation) {
json(Json { ignoreUnknownKeys = true }) json(Json { ignoreUnknownKeys = true })
@@ -33,25 +43,27 @@ class ChatClient(private val token: String) {
} }
} }
suspend fun sendMessage(channelId: Int, userId: Int, text: String) { suspend fun getAccessibleChannels(): List<SpaceDto> = http.get("${BASE_URL}/api/accessible_channels").body()
@OptIn(ExperimentalTime::class)
suspend fun sendMessage(channelId: Long, userId: Int, text: String) {
http.post("${BASE_URL}/api/chat/$channelId") { http.post("${BASE_URL}/api/chat/$channelId") {
contentType(ContentType.Application.Json) contentType(ContentType.Application.Json)
setBody(SendMessage(user_id = userId, text = text, timestamp = System.currentTimeMillis())) setBody(SendMessage(id = Uuid.random(), user_id = userId, text = text, timestamp = Clock.System.now()))
} }
} }
fun messageStream(channelId: Int): Flow<Message> = flow { fun eventStream(channelId: Long): Flow<ChatEvent> = flow {
http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response -> http.prepareGet("${BASE_URL}/api/events/$channelId").execute { response ->
val channel = response.bodyAsChannel() val channel = response.bodyAsChannel()
while (!channel.isClosedForRead) { while (!channel.isClosedForRead) {
val line = channel.readUTF8Line(256) ?: break val line = channel.readLine() ?: break
if (line.startsWith("data:")) { if (line.startsWith("data:")) {
val json = line.removePrefix("data:").trim() val json = line.removePrefix("data:").trim()
runCatching { Json.decodeFromString<Message>(json) } runCatching { Json.decodeFromString<ChatEvent>(json) }
.onSuccess { emit(it) } .onSuccess { emit(it) }
} }
} }
} }
} }
} }
@@ -1,8 +1,10 @@
package dev.zxq5.chatapp.android.api package dev.zxq5.chatapp.android.api
import android.util.Log import android.util.Log
import dev.zxq5.chatapp.android.BuildConfig.BASE_URL
import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest import dev.zxq5.chatapp.android.api.model.AccountDeleteRequest
import dev.zxq5.chatapp.android.api.model.DisplayNameRequest import dev.zxq5.chatapp.android.api.model.DisplayNameRequest
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest import dev.zxq5.chatapp.android.api.model.PasswordChangeRequest
import dev.zxq5.chatapp.android.api.model.QrResponse import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode import dev.zxq5.chatapp.android.api.model.TOTPSixDigitCode
@@ -10,7 +12,6 @@ import dev.zxq5.chatapp.android.api.model.TotpStatus
import dev.zxq5.chatapp.android.api.model.UsernameRequest import dev.zxq5.chatapp.android.api.model.UsernameRequest
import dev.zxq5.chatapp.android.api.model.TotpDeleteRequest import dev.zxq5.chatapp.android.api.model.TotpDeleteRequest
import dev.zxq5.chatapp.android.api.model.PasswordRequest import dev.zxq5.chatapp.android.api.model.PasswordRequest
import dev.zxq5.chatapp.android.core.BASE_URL
import dev.zxq5.chatapp.android.core.error.ApiResult import dev.zxq5.chatapp.android.core.error.ApiResult
import io.ktor.client.HttpClient import io.ktor.client.HttpClient
import io.ktor.client.call.body import io.ktor.client.call.body
@@ -45,6 +46,22 @@ class SettingsClient(private val token: String) {
} }
} }
suspend fun createInvite(request: InviteRequest): ApiResult<String> {
return try {
val response = http.post("${BASE_URL}/api/invite") {
contentType(ContentType.Application.Json)
setBody(request)
}
if (response.status.isSuccess()) {
ApiResult.Success(response.body<String>())
} else {
ApiResult.HttpError(response.status.value, "Failed to create invite")
}
} catch (e: Exception) {
ApiResult.NetworkError(e.localizedMessage ?: "Network error")
}
}
suspend fun getTotpQr(password: String): ApiResult<QrResponse> { suspend fun getTotpQr(password: String): ApiResult<QrResponse> {
return try { return try {
val response = http.post("${BASE_URL}/api/totp.jpg") { val response = http.post("${BASE_URL}/api/totp.jpg") {
@@ -0,0 +1,15 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Channel @OptIn(ExperimentalTime::class) constructor(
val id: Long,
val name: String,
val description: String? = null,
val space_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@@ -0,0 +1,60 @@
@file:OptIn(ExperimentalUuidApi::class, ExperimentalTime::class)
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialName
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
import kotlinx.serialization.Serializable
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
@OptIn(ExperimentalSerializationApi::class)
@Serializable
@JsonClassDiscriminator("type")
sealed class ChatEvent {
@Serializable
@SerialName("SendMessage")
data class SendMessage(
val data: Message
) : ChatEvent()
@Serializable
@SerialName("EditMessage")
data class EditMessage(
val data: EditMessageContent
) : ChatEvent()
@Serializable
@SerialName("MessageAppendContent")
data class MessageAppendContent(
val data: AppendContent
) : ChatEvent()
}
// tuple variants like (i64, ChatMsg) and (i64, String)
// need wrapper classes since kotlinx can't deserialise
// bare JSON arrays into data classes directly
@Serializable
data class EditMessageContent(
val id: Uuid,
val message: Message
)
@Serializable
data class AppendContent (
val id: Uuid,
val content: String
)
@Serializable
data class Message (
val id: Uuid,
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Instant
)
@@ -0,0 +1,13 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class InviteRequest @OptIn(ExperimentalTime::class) constructor(
val name: String,
val max_uses: Int,
val expiry_date: Instant,
val start_date: Instant
)
@@ -1,4 +1,4 @@
package dev.zxq5.chatapp.android.model package dev.zxq5.chatapp.android.api.model
sealed class LoginState { sealed class LoginState {
object Idle : LoginState() object Idle : LoginState()
@@ -1,11 +0,0 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
@Serializable
data class Message(
val user_id: Int,
val display_name: String,
val text: String,
val timestamp: Long
)
@@ -1,10 +1,15 @@
package dev.zxq5.chatapp.android.api.model package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid
@Serializable @Serializable
data class SendMessage( data class SendMessage @OptIn(ExperimentalTime::class, ExperimentalUuidApi::class) constructor(
val id: Uuid,
val user_id: Int, val user_id: Int,
val text: String, val text: String,
val timestamp: Long val timestamp: Instant
) )
@@ -0,0 +1,28 @@
package dev.zxq5.chatapp.android.api.model
import kotlinx.serialization.Serializable
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@Serializable
data class Space @OptIn(ExperimentalTime::class) constructor(
val id: Long,
val name: String,
val description: String? = null,
val owner_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@Serializable
data class SpaceDto @OptIn(ExperimentalTime::class) constructor(
val channels: List<Channel>,
val id: Long,
val name: String,
val description: String? = null,
val owner_id: Long,
val created_at: Instant,
val updated_at: Instant
)
@@ -1,3 +1,3 @@
package dev.zxq5.chatapp.android.core package dev.zxq5.chatapp.android.core
const val BASE_URL = "http://zxq5-x1:8000" //const val BASE_URL = "http://zxq5-x1:8000"
@@ -3,10 +3,12 @@ package dev.zxq5.chatapp.android.core.data
import android.content.Context import android.content.Context
import android.content.SharedPreferences import android.content.SharedPreferences
import android.util.Base64 import android.util.Base64
import android.util.Log
import androidx.core.content.edit import androidx.core.content.edit
import androidx.security.crypto.EncryptedSharedPreferences import androidx.security.crypto.EncryptedSharedPreferences
import androidx.security.crypto.MasterKey import androidx.security.crypto.MasterKey
import org.json.JSONObject import org.json.JSONObject
import java.time.Instant
private const val KEY = "auth_token" private const val KEY = "auth_token"
private const val TWOFA_KEY = "twofa_enabled" private const val TWOFA_KEY = "twofa_enabled"
@@ -27,11 +29,37 @@ class TokenStore(appContext: Context) {
) )
} }
fun save(token: String) = fun save(token: String) {
prefs().edit { putString(KEY, token) } Log.d("TokenStore", "Saving token: $token")
prefs().edit { putString(KEY, token) }
}
fun get(): String? {
val ret = prefs().getString(KEY, null)
Log.d("TokenStore", "Retrieved token: $ret")
return ret
}
fun isExpired(): Boolean {
val token = get() ?: return true
return try {
val payload = token.split(".")[1]
val padded = payload + "==".take((4 - payload.length % 4) % 4)
val jsonString = String(Base64.decode(padded, Base64.URL_SAFE))
val json = JSONObject(jsonString)
if (json.has("exp")) {
val exp = json.getLong("exp")
val now = Instant.now().epochSecond
now >= exp
} else {
false // If no exp claim, assume not expired or handle differently
}
} catch (e: Exception) {
true // If we can't parse it, treat it as expired
}
}
fun get(): String? =
prefs().getString(KEY, null)
fun save2faEnabled( enabled: Boolean) = fun save2faEnabled( enabled: Boolean) =
prefs().edit { putBoolean(TWOFA_KEY, enabled) } prefs().edit { putBoolean(TWOFA_KEY, enabled) }
@@ -0,0 +1,92 @@
package dev.zxq5.chatapp.android.core.service
import android.app.Service
import android.content.Context
import android.content.Intent
import android.os.IBinder
import android.util.Log
import dev.zxq5.chatapp.android.ChatApplication
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.data.repository.ChatRepository
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
class MessageStreamService : Service() {
private val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private lateinit var notificationService: NotificationService
private lateinit var chatRepository: ChatRepository
var activeChannelId: Long? = null
set(value) {
field = value
Log.d("Service", "activeChannelId set to $value")
if (value != null) {
currentStreamJob?.cancel()
observeMessages()
}
}
private var currentStreamJob: kotlinx.coroutines.Job? = null
companion object {
var instance: MessageStreamService? = null
fun start(context: Context) {
val intent = Intent(context, MessageStreamService::class.java)
context.startService(intent)
}
fun stop(context: Context) {
context.stopService(Intent(context, MessageStreamService::class.java))
}
}
override fun onCreate() {
super.onCreate()
instance = this
notificationService = NotificationService(this)
chatRepository = (application as ChatApplication).chatRepository
}
override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int {
observeMessages()
return START_STICKY
}
private fun observeMessages() {
val channelId = activeChannelId ?: chatRepository.getLastActiveChannel()
if (channelId == null) return
currentStreamJob = serviceScope.launch {
chatRepository.eventStream(channelId)
.catch { e -> Log.e("Service", "Stream error", e) }
.collect { event ->
// Only show notification when an event (new message) is received
// and the app is not in the foreground on this channel.
if (!ChatApplication.AppState.isInForeground || activeChannelId != channelId) {
when (event) {
is ChatEvent.SendMessage -> notificationService.showMessageNotification(
conversationId = channelId.toString(),
senderName = event.data.display_name,
messagePreview = event.data.text
)
else -> {}
}
}
}
}
}
override fun onBind(intent: Intent?): IBinder? = null
override fun onDestroy() {
super.onDestroy()
instance = null
serviceScope.cancel()
}
}
@@ -0,0 +1,84 @@
package dev.zxq5.chatapp.android.core.service
import android.app.Notification
import android.app.NotificationChannel
import android.app.NotificationManager
import android.app.PendingIntent
import android.content.Context
import android.content.Intent
import android.os.Build
import androidx.core.app.NotificationCompat
import dev.zxq5.chatapp.android.MainActivity
import dev.zxq5.chatapp.android.R
class NotificationService(private val context: Context) {
companion object {
const val CHANNEL_ID = "messages"
const val SERVICE_CHANNEL_ID = "service"
const val FOREGROUND_NOTIFICATION_ID = 1
}
private val manager = context.getSystemService(NotificationManager::class.java)
fun createForegroundNotification(): Notification {
val intent = Intent(context, MainActivity::class.java).apply {
flags = Intent.FLAG_ACTIVITY_NEW_TASK or Intent.FLAG_ACTIVITY_CLEAR_TASK
}
val pendingIntent = PendingIntent.getActivity(
context,
0,
intent,
PendingIntent.FLAG_IMMUTABLE
)
return NotificationCompat.Builder(context, SERVICE_CHANNEL_ID)
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle("Chat App")
.setContentText("Connecting to message stream...")
.setPriority(NotificationCompat.PRIORITY_LOW)
.setCategory(NotificationCompat.CATEGORY_SERVICE)
.setContentIntent(pendingIntent)
.setOngoing(true)
.build()
}
fun showMessageNotification(
conversationId: String,
senderName: String,
messagePreview: String,
notificationId: Int = conversationId.hashCode()
) {
val intent = Intent(context, MainActivity::class.java).apply {
flags = Intent.FLAG_ACTIVITY_SINGLE_TOP
putExtra("conversation_id", conversationId)
}
val pendingIntent = PendingIntent.getActivity(
context,
notificationId,
intent,
PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE
)
val notification = NotificationCompat.Builder(context, CHANNEL_ID)
.setSmallIcon(R.drawable.ic_notification)
.setContentTitle(senderName)
.setContentText(messagePreview)
.setPriority(NotificationCompat.PRIORITY_HIGH)
.setContentIntent(pendingIntent)
.setAutoCancel(true)
.build()
manager.notify(notificationId, notification)
}
fun dismissNotification(conversationId: String) {
manager.cancel(conversationId.hashCode())
}
fun dismissAll() {
manager.cancelAll()
}
}
@@ -56,6 +56,10 @@ class AuthRepository(
fun getAuthState(): AuthState { fun getAuthState(): AuthState {
val token = tokenStore.get() ?: return AuthState.Unauthenticated val token = tokenStore.get() ?: return AuthState.Unauthenticated
if (tokenStore.isExpired()) {
tokenStore.clear()
return AuthState.Unauthenticated
}
return when (getScopeFromToken(token)) { return when (getScopeFromToken(token)) {
TokenScope.FULL -> AuthState.Authenticated TokenScope.FULL -> AuthState.Authenticated
TokenScope.TOTP_PENDING -> AuthState.AwaitingTotp TokenScope.TOTP_PENDING -> AuthState.AwaitingTotp
@@ -1,8 +1,10 @@
package dev.zxq5.chatapp.android.data.repository package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.ChatClient import dev.zxq5.chatapp.android.api.ChatClient
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SpaceDto
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.emptyFlow
@@ -11,6 +13,8 @@ class ChatRepository(private val tokenStore: TokenStore) {
private var _chatClient: ChatClient? = null private var _chatClient: ChatClient? = null
private var _lastToken: String? = null private var _lastToken: String? = null
private var _lastActiveChannel: Long? = null
private fun getChatClient(): ChatClient? { private fun getChatClient(): ChatClient? {
val token = tokenStore.get() ?: return null val token = tokenStore.get() ?: return null
if (_chatClient == null || token != _lastToken) { if (_chatClient == null || token != _lastToken) {
@@ -25,14 +29,23 @@ class ChatRepository(private val tokenStore: TokenStore) {
_lastToken = null _lastToken = null
} }
fun getLastActiveChannel(): Long? {
return _lastActiveChannel
}
fun getUserId() = tokenStore.getUserId() fun getUserId() = tokenStore.getUserId()
suspend fun sendMessage(channelId: Int, text: String) { suspend fun getAccessibleChannels(): List<SpaceDto> {
return getChatClient()?.getAccessibleChannels() ?: emptyList()
}
suspend fun sendMessage(channelId: Long, text: String) {
val userId = tokenStore.getUserId() ?: return val userId = tokenStore.getUserId() ?: return
getChatClient()?.sendMessage(channelId, userId, text) getChatClient()?.sendMessage(channelId, userId, text)
} }
fun messageStream(channelId: Int): Flow<Message> { fun eventStream(channelId: Long): Flow<ChatEvent> {
return getChatClient()?.messageStream(channelId) ?: emptyFlow() _lastActiveChannel = channelId
return getChatClient()?.eventStream(channelId) ?: emptyFlow()
} }
} }
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.data.repository
import dev.zxq5.chatapp.android.api.model.QrResponse import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.api.SettingsClient import dev.zxq5.chatapp.android.api.SettingsClient
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.TotpStatus import dev.zxq5.chatapp.android.api.model.TotpStatus
import dev.zxq5.chatapp.android.core.data.TokenStore import dev.zxq5.chatapp.android.core.data.TokenStore
import dev.zxq5.chatapp.android.core.error.ApiResult import dev.zxq5.chatapp.android.core.error.ApiResult
@@ -25,6 +26,10 @@ class SettingsRepository(private val tokenStore: TokenStore) {
_lastToken = null _lastToken = null
} }
suspend fun createInvite(request: InviteRequest): ApiResult<String> {
return getSettingsClient()?.createInvite(request) ?: ApiResult.NetworkError("Not authenticated")
}
suspend fun getTotpQr(password: String): ApiResult<QrResponse?> { suspend fun getTotpQr(password: String): ApiResult<QrResponse?> {
val settingsClient = getSettingsClient() ?: return ApiResult.NetworkError("Not authenticated") val settingsClient = getSettingsClient() ?: return ApiResult.NetworkError("Not authenticated")
return settingsClient.getTotpQr(password) return settingsClient.getTotpQr(password)
@@ -3,7 +3,7 @@ package dev.zxq5.chatapp.android.feature.auth
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import dev.zxq5.chatapp.android.model.LoginState import dev.zxq5.chatapp.android.api.model.LoginState
@Composable @Composable
fun AuthScreen(viewModel: AuthViewModel) { fun AuthScreen(viewModel: AuthViewModel) {
@@ -2,11 +2,12 @@ package dev.zxq5.chatapp.android.feature.auth
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.core.service.MessageStreamService
import dev.zxq5.chatapp.android.data.repository.AuthRepository import dev.zxq5.chatapp.android.data.repository.AuthRepository
import dev.zxq5.chatapp.android.data.repository.LoginResult import dev.zxq5.chatapp.android.data.repository.LoginResult
import dev.zxq5.chatapp.android.data.repository.SignupResult import dev.zxq5.chatapp.android.data.repository.SignupResult
import dev.zxq5.chatapp.android.data.repository.AuthState import dev.zxq5.chatapp.android.data.repository.AuthState
import dev.zxq5.chatapp.android.model.LoginState import dev.zxq5.chatapp.android.api.model.LoginState
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@@ -28,7 +28,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.model.LoginState import dev.zxq5.chatapp.android.api.model.LoginState
import dev.zxq5.chatapp.android.ui.components.TextField import dev.zxq5.chatapp.android.ui.components.TextField
@Composable @Composable
@@ -28,7 +28,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.model.LoginState import dev.zxq5.chatapp.android.api.model.LoginState
import dev.zxq5.chatapp.android.ui.components.TextField import dev.zxq5.chatapp.android.ui.components.TextField
@Composable @Composable
@@ -1,26 +1,33 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.feature.chat package dev.zxq5.chatapp.android.feature.chat
import android.util.Log import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.ChatEvent
import dev.zxq5.chatapp.android.data.repository.ChatRepository import dev.zxq5.chatapp.android.data.repository.ChatRepository
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import dev.zxq5.chatapp.android.api.model.SpaceDto
import dev.zxq5.chatapp.android.core.service.MessageStreamService
import io.ktor.client.plugins.ResponseException
import io.ktor.http.HttpStatusCode
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() { class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
private val _messages = MutableStateFlow<List<Message>>(emptyList()) private val _messages = MutableStateFlow<List<Message>>(emptyList())
val messages: StateFlow<List<Message>> = _messages val messages: StateFlow<List<Message>> = _messages
private val _channelId = MutableStateFlow<Int?>(null) private val _channelId = MutableStateFlow<Long?>(null)
val channelId: StateFlow<Int?> = _channelId val channelId: StateFlow<Long?> = _channelId
private val _currentScreen = MutableStateFlow(Screen.CHAT) private val _currentScreen = MutableStateFlow(Screen.CHAT)
val currentScreen: StateFlow<Screen> = _currentScreen val currentScreen: StateFlow<Screen> = _currentScreen
@@ -28,26 +35,87 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
private val _currentUserId = MutableStateFlow<Int?>(null) private val _currentUserId = MutableStateFlow<Int?>(null)
val currentUserId: StateFlow<Int?> = _currentUserId val currentUserId: StateFlow<Int?> = _currentUserId
private val _spaces = MutableStateFlow<List<SpaceDto>>(emptyList())
val spaces: StateFlow<List<SpaceDto>> = _spaces
private val _error = MutableStateFlow<String?>(null)
val error: StateFlow<String?> = _error
private val _channelError = MutableStateFlow<String?>(null)
val channelError: StateFlow<String?> = _channelError
private var streamJob: Job? = null private var streamJob: Job? = null
var onUnauthorized: (() -> Unit)? = null
init { init {
_currentUserId.value = chatRepository.getUserId() _currentUserId.value = chatRepository.getUserId()
observeChannel() observeChannel()
loadAccessibleChannels()
} }
fun loadAccessibleChannels() {
_error.value = null
_currentUserId.value = chatRepository.getUserId()
viewModelScope.launch {
runCatching {
chatRepository.getAccessibleChannels()
}.onSuccess { data ->
_spaces.value = data
}.onFailure { e ->
Log.e("Chat", "Failed to load spaces", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_error.value = "Failed to load channels: ${e.message}"
}
}
}
}
@OptIn(ExperimentalTime::class)
private fun observeChannel() { private fun observeChannel() {
viewModelScope.launch { viewModelScope.launch {
_channelId.collect { id -> _channelId.collect { id ->
streamJob?.cancel() streamJob?.cancel()
_messages.value = emptyList() _messages.value = emptyList()
_channelError.value = null
if (id != null) { if (id != null) {
streamJob = launch { streamJob = launch {
chatRepository.messageStream(id) chatRepository.eventStream(id)
.catch { e -> .catch { e ->
Log.e("Chat", "Stream error", e) Log.e("Chat", "Stream error", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_channelError.value = "Connection lost: ${e.message}"
}
} }
.collect { message -> .collect { event ->
_messages.update { it + message } when (event) {
is ChatEvent.SendMessage -> {
_messages.update { it + event.data }
}
is ChatEvent.EditMessage -> {
_messages.update { messages ->
messages.map {
if (it.id == event.data.id) event.data.message
else it
}
}
}
is ChatEvent.MessageAppendContent -> {
_messages.update { messages ->
messages.map { msg ->
if (msg.id == event.data.id) {
msg.copy(text = msg.text + event.data.content)
} else {
msg
}
}
}
}
}
} }
} }
} }
@@ -59,12 +127,14 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_currentScreen.value = screen _currentScreen.value = screen
} }
fun switchChannel(id: Int?) { fun switchChannel(id: Long?) {
_channelId.value = id _channelId.value = id
MessageStreamService.instance?.activeChannelId = id
if (id != null) { if (id != null) {
// Refresh user ID just in case it wasn't available at init // Refresh user ID just in case it wasn't available at init
_currentUserId.value = chatRepository.getUserId() _currentUserId.value = chatRepository.getUserId()
chatRepository.resetClient()
} }
} }
@@ -78,6 +148,11 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
) )
}.onFailure { e -> }.onFailure { e ->
Log.e("Chat", "Send message error", e) Log.e("Chat", "Send message error", e)
if (e is ResponseException && e.response.status == HttpStatusCode.Unauthorized) {
onUnauthorized?.invoke()
} else {
_channelError.value = "Failed to send message"
}
} }
} }
} }
@@ -86,8 +161,15 @@ class ChatViewModel(private val chatRepository: ChatRepository) : ViewModel() {
_messages.value = emptyList() _messages.value = emptyList()
_channelId.value = null _channelId.value = null
_currentUserId.value = null _currentUserId.value = null
_spaces.value = emptyList()
_error.value = null
_channelError.value = null
streamJob?.cancel() streamJob?.cancel()
chatRepository.resetClient() chatRepository.resetClient()
MessageStreamService.instance?.activeChannelId = null
}
fun clearChannelError() {
_channelError.value = null
} }
} }
@@ -1,5 +1,5 @@
package dev.zxq5.chatapp.android.feature.chat package dev.zxq5.chatapp.android.feature.chat
enum class Screen { enum class Screen {
CHAT, SETTINGS CHAT, CONTACTS, SETTINGS
} }
@@ -1,3 +1,5 @@
@file:OptIn(ExperimentalUuidApi::class)
package dev.zxq5.chatapp.android.feature.chat package dev.zxq5.chatapp.android.feature.chat
import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.BorderStroke
@@ -28,22 +30,21 @@ import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.filled.ArrowBack import androidx.compose.material.icons.automirrored.filled.ArrowBack
import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Add
import androidx.compose.material.icons.filled.Refresh
import androidx.compose.material.icons.filled.Send
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.NavigationBar
import androidx.compose.material3.NavigationBarItem
import androidx.compose.material3.NavigationBarItemDefaults
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.SnackbarHost
import androidx.compose.material3.SnackbarHostState
import androidx.compose.material3.Surface import androidx.compose.material3.Surface
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBar import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material.icons.filled.Send
import androidx.compose.material.icons.outlined.ChatBubbleOutline
import androidx.compose.material.icons.outlined.Settings
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
@@ -56,26 +57,30 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.SolidColor import androidx.compose.ui.graphics.SolidColor
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.input.ImeAction import androidx.compose.ui.text.input.ImeAction
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.Dp import androidx.compose.ui.unit.Dp
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import dev.zxq5.chatapp.android.api.model.Channel
import dev.zxq5.chatapp.android.api.model.Message import dev.zxq5.chatapp.android.api.model.Message
import java.text.DateFormat import java.text.DateFormat
import java.util.Date import java.util.Date
import kotlin.time.ExperimentalTime
import kotlin.uuid.ExperimentalUuidApi
@Composable @Composable
fun ChatScreen( fun ChatScreen(
viewModel: ChatViewModel, viewModel: ChatViewModel,
onNavigateToSettings: () -> Unit, onNavigateToSettings: () -> Unit,
onLogout: () -> Unit // Note: logout is now part of SettingsScreen in this UI, but we'll keep the param for now onLogout: () -> Unit
) { ) {
val selectedChannelId by viewModel.channelId.collectAsState() val selectedChannelId by viewModel.channelId.collectAsState()
if (selectedChannelId == null) { if (selectedChannelId == null) {
ChannelListScreen( ChannelListScreen(
viewModel = viewModel, viewModel = viewModel,
onChannelSelect = { viewModel.switchChannel(it) }, onChannelSelect = { viewModel.switchChannel(it) }
onNavigateToSettings = onNavigateToSettings
) )
} else { } else {
MessageScreen( MessageScreen(
@@ -90,20 +95,15 @@ fun ChatScreen(
@Composable @Composable
fun ChannelListScreen( fun ChannelListScreen(
viewModel: ChatViewModel, viewModel: ChatViewModel,
onChannelSelect: (Int) -> Unit, onChannelSelect: (Long) -> Unit
onNavigateToSettings: () -> Unit
) { ) {
val spaces by viewModel.spaces.collectAsState()
val error by viewModel.error.collectAsState()
Scaffold( Scaffold(
containerColor = MaterialTheme.colorScheme.background, containerColor = MaterialTheme.colorScheme.background,
topBar = { topBar = {
Column { Column {
Spacer(Modifier.height(8.dp))
Text(
"contacts",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
modifier = Modifier.padding(horizontal = 20.dp)
)
TopAppBar( TopAppBar(
title = { title = {
Text( Text(
@@ -115,103 +115,69 @@ fun ChannelListScreen(
colors = TopAppBarDefaults.topAppBarColors( colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent, containerColor = Color.Transparent,
titleContentColor = MaterialTheme.colorScheme.onSurface titleContentColor = MaterialTheme.colorScheme.onSurface
) ),
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
) )
Text( Text(
"5 channels · end-to-end encrypted", "Public channels - dms coming soon.",
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f), color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f),
modifier = Modifier.padding(horizontal = 20.dp, vertical = 2.dp) modifier = Modifier.padding(horizontal = 20.dp, vertical = 2.dp)
) )
Spacer(Modifier.height(12.dp)) Spacer(Modifier.height(12.dp))
}
Row( }
modifier = Modifier ) { padding ->
.fillMaxWidth() if (error != null) {
.background(MaterialTheme.colorScheme.secondary.copy(alpha = 0.2f)) Column(
.padding(horizontal = 20.dp, vertical = 8.dp), modifier = Modifier.fillMaxSize().padding(padding),
verticalAlignment = Alignment.CenterVertically horizontalAlignment = Alignment.CenterHorizontally,
) { verticalArrangement = Arrangement.Center
Box( ) {
modifier = Modifier Text(
.size(6.dp) text = error!!,
.clip(CircleShape) color = MaterialTheme.colorScheme.error,
.background(MaterialTheme.colorScheme.primary) textAlign = TextAlign.Center,
) modifier = Modifier.padding(16.dp)
Spacer(Modifier.width(10.dp)) )
Text( Button(onClick = { viewModel.loadAccessibleChannels() }) {
"global · walkie talkie", Icon(Icons.Default.Refresh, contentDescription = null)
style = MaterialTheme.typography.labelSmall, Spacer(Modifier.width(8.dp))
color = MaterialTheme.colorScheme.onSurfaceVariant, Text("Retry")
modifier = Modifier.weight(1f)
)
Surface(
color = Color.Transparent,
border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant),
shape = RoundedCornerShape(4.dp)
) {
Text(
"hold to talk",
modifier = Modifier.padding(horizontal = 8.dp, vertical = 3.dp),
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
} }
} }
}, } else {
bottomBar = { BottomDock(viewModel, onNavigateToSettings) } LazyColumn(modifier = Modifier.padding(padding).fillMaxSize()) {
) { padding -> spaces.forEach { spaceDto ->
LazyColumn(modifier = Modifier.padding(padding).fillMaxSize()) { item {
items(10) { i -> Text(
val id = i + 1 text = spaceDto.name.lowercase(),
ChannelItem(id = id, onClick = { onChannelSelect(id) }) modifier = Modifier.padding(horizontal = 20.dp, vertical = 8.dp),
HorizontalDivider( style = MaterialTheme.typography.labelMedium,
modifier = Modifier.padding(horizontal = 20.dp), color = MaterialTheme.colorScheme.primary,
thickness = 0.5.dp, fontWeight = FontWeight.Bold
color = MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f) )
) }
items(spaceDto.channels) { channel ->
ChannelItem(channel = channel, onClick = { onChannelSelect(channel.id) })
HorizontalDivider(
modifier = Modifier.padding(horizontal = 20.dp),
thickness = 0.5.dp,
color = MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f)
)
}
item {
Spacer(Modifier.height(16.dp))
}
}
} }
} }
} }
} }
@Composable @Composable
fun BottomDock(viewModel: ChatViewModel, onNavigateToSettings: () -> Unit) { fun ChannelItem(channel: Channel, onClick: () -> Unit) {
val currentScreen by viewModel.currentScreen.collectAsState()
NavigationBar(
containerColor = MaterialTheme.colorScheme.background,
tonalElevation = 0.dp,
modifier = Modifier.border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.2f), RoundedCornerShape(topStart = 0.dp, topEnd = 0.dp))
) {
NavigationBarItem(
selected = currentScreen == Screen.CHAT,
onClick = { viewModel.navigateTo(Screen.CHAT) },
icon = { Icon(Icons.Outlined.ChatBubbleOutline, contentDescription = "Chat") },
label = { Text("chat", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
NavigationBarItem(
selected = currentScreen == Screen.SETTINGS,
onClick = onNavigateToSettings,
icon = { Icon(Icons.Outlined.Settings, contentDescription = "Settings") },
label = { Text("settings", style = MaterialTheme.typography.labelSmall) },
colors = NavigationBarItemDefaults.colors(
selectedIconColor = MaterialTheme.colorScheme.primary,
unselectedIconColor = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
indicatorColor = Color.Transparent
)
)
}
}
@Composable
fun ChannelItem(id: Int, onClick: () -> Unit) {
Row( Row(
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
@@ -227,7 +193,7 @@ fun ChannelItem(id: Int, onClick: () -> Unit) {
contentAlignment = Alignment.Center contentAlignment = Alignment.Center
) { ) {
Text( Text(
"C$id", channel.name.take(1).uppercase(),
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
) )
@@ -235,31 +201,30 @@ fun ChannelItem(id: Int, onClick: () -> Unit) {
Spacer(Modifier.width(12.dp)) Spacer(Modifier.width(12.dp))
Column(modifier = Modifier.weight(1f)) { Column(modifier = Modifier.weight(1f)) {
Text( Text(
text = "channel $id", text = channel.name,
style = MaterialTheme.typography.bodyLarge, style = MaterialTheme.typography.bodyLarge,
color = MaterialTheme.colorScheme.onSurface color = MaterialTheme.colorScheme.onSurface
) )
Text( if (channel.description != null) {
text = "tap to join", Text(
style = MaterialTheme.typography.labelSmall, text = channel.description,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f) style = MaterialTheme.typography.labelSmall,
) color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f)
)
}
} }
Text(
"14:22",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f)
)
} }
} }
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit) { fun MessageScreen(channelId: Long, viewModel: ChatViewModel, onBack: () -> Unit) {
val messages by viewModel.messages.collectAsState() val messages by viewModel.messages.collectAsState()
val currentUserId by viewModel.currentUserId.collectAsState() val currentUserId by viewModel.currentUserId.collectAsState()
val channelError by viewModel.channelError.collectAsState()
var input by remember { mutableStateOf("") } var input by remember { mutableStateOf("") }
val listState = rememberLazyListState() val listState = rememberLazyListState()
val snackbarHostState = remember { SnackbarHostState() }
LaunchedEffect(messages.size) { LaunchedEffect(messages.size) {
if (messages.isNotEmpty()) { if (messages.isNotEmpty()) {
@@ -267,8 +232,16 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
} }
} }
LaunchedEffect(channelError) {
channelError?.let {
snackbarHostState.showSnackbar(it)
viewModel.clearChannelError()
}
}
Scaffold( Scaffold(
containerColor = MaterialTheme.colorScheme.background, containerColor = MaterialTheme.colorScheme.background,
snackbarHost = { SnackbarHost(snackbarHostState) },
topBar = { topBar = {
TopAppBar( TopAppBar(
title = { title = {
@@ -294,6 +267,7 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
) )
} }
}, },
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
colors = TopAppBarDefaults.topAppBarColors( colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent containerColor = Color.Transparent
) )
@@ -306,7 +280,7 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
modifier = Modifier.weight(1f).padding(horizontal = 16.dp), modifier = Modifier.weight(1f).padding(horizontal = 16.dp),
verticalArrangement = Arrangement.spacedBy(10.dp) verticalArrangement = Arrangement.spacedBy(10.dp)
) { ) {
items(messages) { message -> items(messages, key = { it.id }) { message ->
MessageBubble(message, currentUserId) MessageBubble(message, currentUserId)
} }
item { Spacer(Modifier.height(10.dp)) } item { Spacer(Modifier.height(10.dp)) }
@@ -391,10 +365,13 @@ fun MessageScreen(channelId: Int, viewModel: ChatViewModel, onBack: () -> Unit)
} }
} }
@OptIn(ExperimentalTime::class)
@Composable @Composable
fun MessageBubble(message: Message, currentUserId: Int?) { fun MessageBubble(message: Message, currentUserId: Int?) {
val time = remember(message.timestamp) { val time = remember(message.timestamp) {
DateFormat.getTimeInstance(DateFormat.SHORT).format(Date(message.timestamp)).lowercase() DateFormat.getTimeInstance(DateFormat.SHORT)
.format(Date(message.timestamp.toEpochMilliseconds()))
.lowercase()
} }
val isMe = currentUserId != null && message.user_id == currentUserId val isMe = currentUserId != null && message.user_id == currentUserId
@@ -404,7 +381,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
horizontalAlignment = if (isMe) Alignment.End else Alignment.Start horizontalAlignment = if (isMe) Alignment.End else Alignment.Start
) { ) {
Surface( Surface(
color = if (isMe) MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f), color = if (isMe) MaterialTheme.colorScheme.surfaceVariant else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.2f),
shape = RoundedCornerShape( shape = RoundedCornerShape(
topStart = 14.dp, topStart = 14.dp,
topEnd = 14.dp, topEnd = 14.dp,
@@ -414,14 +391,7 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f)) border = border(0.5.dp, MaterialTheme.colorScheme.outlineVariant.copy(alpha = 0.5f))
) { ) {
Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) { Column(modifier = Modifier.padding(horizontal = 11.dp, vertical = 8.dp)) {
if (!isMe) {
Text(
message.display_name?.lowercase() ?: "unknown",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.primary.copy(alpha = 0.7f),
modifier = Modifier.padding(bottom = 2.dp)
)
}
Text( Text(
text = message.text, text = message.text,
style = MaterialTheme.typography.bodyLarge, style = MaterialTheme.typography.bodyLarge,
@@ -429,10 +399,11 @@ fun MessageBubble(message: Message, currentUserId: Int?) {
) )
} }
} }
Text( Text(
text = time, text = if (!isMe) message.display_name.lowercase() + " . " + time else time,
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f), color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.5f),
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp) modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp)
) )
} }
@@ -0,0 +1,52 @@
package dev.zxq5.chatapp.android.feature.contacts
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ContactsScreen() {
Scaffold(
containerColor = MaterialTheme.colorScheme.background,
topBar = {
TopAppBar(
title = {
Text(
"contacts",
style = MaterialTheme.typography.titleLarge,
color = MaterialTheme.colorScheme.onSurface
)
},
windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
colors = TopAppBarDefaults.topAppBarColors(
containerColor = Color.Transparent,
titleContentColor = MaterialTheme.colorScheme.onSurface
)
)
}
) { padding ->
Box(
modifier = Modifier
.fillMaxSize()
.padding(padding),
contentAlignment = Alignment.Center
) {
Text(
text = "Contacts coming soon",
style = MaterialTheme.typography.bodyLarge,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
@@ -2,6 +2,7 @@ package dev.zxq5.chatapp.android.feature.settings
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import dev.zxq5.chatapp.android.api.model.InviteRequest
import dev.zxq5.chatapp.android.api.model.QrResponse import dev.zxq5.chatapp.android.api.model.QrResponse
import dev.zxq5.chatapp.android.core.error.ApiResult import dev.zxq5.chatapp.android.core.error.ApiResult
import dev.zxq5.chatapp.android.data.repository.SettingsRepository import dev.zxq5.chatapp.android.data.repository.SettingsRepository
@@ -27,6 +28,9 @@ class SettingsViewModel(private val settingsRepository: SettingsRepository) : Vi
private val _isSuccessState = MutableStateFlow<Map<String, Boolean>>(emptyMap()) private val _isSuccessState = MutableStateFlow<Map<String, Boolean>>(emptyMap())
val isSuccessState: StateFlow<Map<String, Boolean>> = _isSuccessState val isSuccessState: StateFlow<Map<String, Boolean>> = _isSuccessState
private val _lastInviteCode = MutableStateFlow<String?>(null)
val lastInviteCode: StateFlow<String?> = _lastInviteCode
fun clearMessages() { fun clearMessages() {
_settingsError.value = null _settingsError.value = null
_totpError.value = null _totpError.value = null
@@ -40,6 +44,20 @@ class SettingsViewModel(private val settingsRepository: SettingsRepository) : Vi
} }
} }
fun createInvite(request: InviteRequest) {
viewModelScope.launch {
_settingsError.value = null
when (val result = settingsRepository.createInvite(request)) {
is ApiResult.Success -> {
_lastInviteCode.value = result.data
triggerSuccess("invite")
}
is ApiResult.HttpError -> _settingsError.value = result.message
is ApiResult.NetworkError -> _settingsError.value = result.message
}
}
}
fun fetchTotpStatus() { fun fetchTotpStatus() {
viewModelScope.launch { viewModelScope.launch {
when (val result = settingsRepository.getTotpStatus()) { when (val result = settingsRepository.getTotpStatus()) {
@@ -23,11 +23,14 @@ import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.filled.ArrowBack import androidx.compose.material.icons.automirrored.filled.ArrowBack
import androidx.compose.material.icons.filled.ContentCopy
import androidx.compose.material.icons.filled.KeyboardArrowDown import androidx.compose.material.icons.filled.KeyboardArrowDown
import androidx.compose.material.icons.filled.KeyboardArrowUp import androidx.compose.material.icons.filled.KeyboardArrowUp
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.ButtonDefaults import androidx.compose.material3.ButtonDefaults
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.DatePicker
import androidx.compose.material3.DatePickerDialog
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
@@ -37,8 +40,10 @@ import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.OutlinedTextFieldDefaults import androidx.compose.material3.OutlinedTextFieldDefaults
import androidx.compose.material3.Scaffold import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.TopAppBar import androidx.compose.material3.TopAppBar
import androidx.compose.material3.TopAppBarDefaults import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.rememberDatePickerState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
@@ -57,13 +62,18 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import android.util.Base64 import android.util.Base64
import android.graphics.BitmapFactory import android.graphics.BitmapFactory
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.text.AnnotatedString
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import dev.zxq5.chatapp.android.api.model.InviteRequest
import kotlin.time.Duration.Companion.days
import kotlin.time.ExperimentalTime
import kotlin.time.Instant
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class, ExperimentalTime::class)
@Composable @Composable
fun SettingsScreen( fun SettingsScreen(
viewModel: SettingsViewModel, viewModel: SettingsViewModel,
onBack: () -> Unit,
onLogout: () -> Unit onLogout: () -> Unit
) { ) {
val is2faEnabled by viewModel.is2faEnabled.collectAsState() val is2faEnabled by viewModel.is2faEnabled.collectAsState()
@@ -71,6 +81,7 @@ fun SettingsScreen(
val settingsError by viewModel.settingsError.collectAsState() val settingsError by viewModel.settingsError.collectAsState()
val isSuccessState by viewModel.isSuccessState.collectAsState() val isSuccessState by viewModel.isSuccessState.collectAsState()
val totpError by viewModel.totpError.collectAsState() val totpError by viewModel.totpError.collectAsState()
val lastInviteCode by viewModel.lastInviteCode.collectAsState()
LaunchedEffect(Unit) { LaunchedEffect(Unit) {
viewModel.clearMessages() viewModel.clearMessages()
@@ -88,15 +99,7 @@ fun SettingsScreen(
color = MaterialTheme.colorScheme.onSurface color = MaterialTheme.colorScheme.onSurface
) )
}, },
navigationIcon = { windowInsets = androidx.compose.foundation.layout.WindowInsets(0, 0, 0, 0),
IconButton(onClick = onBack) {
Icon(
Icons.AutoMirrored.Filled.ArrowBack,
contentDescription = "Back",
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
}
},
colors = TopAppBarDefaults.topAppBarColors(containerColor = Color.Transparent) colors = TopAppBarDefaults.topAppBarColors(containerColor = Color.Transparent)
) )
} }
@@ -283,6 +286,120 @@ fun SettingsScreen(
} }
} }
SettingsSection(title = "invite") {
var inviteName by remember { mutableStateOf("") }
var maxUses by remember { mutableStateOf("1") }
val clipboardManager = LocalClipboardManager.current
var showDatePicker by remember { mutableStateOf(false) }
val datePickerState = rememberDatePickerState(
initialSelectedDateMillis = System.currentTimeMillis() + 7.days.inWholeMilliseconds
)
Text("create invite token", style = MaterialTheme.typography.bodyMedium, modifier = Modifier.padding(bottom = 8.dp))
OutlinedTextField(
value = inviteName,
onValueChange = { inviteName = it },
label = { Text("name") },
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp)
)
Spacer(Modifier.height(8.dp))
OutlinedTextField(
value = maxUses,
onValueChange = { if (it.all { c -> c.isDigit() }) maxUses = it },
label = { Text("max uses") },
keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp)
)
Spacer(Modifier.height(8.dp))
OutlinedTextField(
value = datePickerState.selectedDateMillis?.let { Instant.fromEpochMilliseconds(it).toString().substringBefore("T") } ?: "",
onValueChange = {},
label = { Text("expiry date") },
readOnly = true,
trailingIcon = {
IconButton(onClick = { showDatePicker = true }) {
Icon(Icons.Default.KeyboardArrowDown, contentDescription = "Select Date")
}
},
modifier = Modifier.fillMaxWidth().clickable { showDatePicker = true },
shape = RoundedCornerShape(8.dp)
)
if (showDatePicker) {
DatePickerDialog(
onDismissRequest = { showDatePicker = false },
confirmButton = {
TextButton(onClick = { showDatePicker = false }) {
Text("ok")
}
}
) {
DatePicker(state = datePickerState)
}
}
Spacer(Modifier.height(12.dp))
SuccessButton(
onClick = {
val nowMs = System.currentTimeMillis()
val expiryMs = datePickerState.selectedDateMillis ?: (nowMs + 7.days.inWholeMilliseconds)
viewModel.createInvite(
InviteRequest(
name = inviteName,
max_uses = maxUses.toIntOrNull() ?: 1,
start_date = Instant.fromEpochMilliseconds(nowMs),
expiry_date = Instant.fromEpochMilliseconds(expiryMs)
)
)
},
label = "generate invite",
isSuccess = isSuccessState["invite"] == true,
enabled = inviteName.isNotBlank(),
modifier = Modifier.fillMaxWidth()
)
if (lastInviteCode != null) {
Spacer(Modifier.height(16.dp))
Row(
modifier = Modifier
.fillMaxWidth()
.background(MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f), RoundedCornerShape(8.dp))
.padding(12.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween
) {
Text(
text = lastInviteCode!!,
style = MaterialTheme.typography.bodyLarge,
modifier = Modifier.weight(1f)
)
IconButton(onClick = {
clipboardManager.setText(AnnotatedString(lastInviteCode!!))
}) {
Icon(Icons.Default.ContentCopy, contentDescription = "Copy", modifier = Modifier.size(20.dp))
}
}
}
}
SettingsSection(title = "session") {
Button(
onClick = onLogout,
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp),
colors = ButtonDefaults.buttonColors(containerColor = Color.White, contentColor = Color.Black)
) {
Text("logout")
}
}
SettingsSection(title = "danger zone", color = Color.Red.copy(alpha = 0.7f)) { SettingsSection(title = "danger zone", color = Color.Red.copy(alpha = 0.7f)) {
var deletePassword by remember { mutableStateOf("") } var deletePassword by remember { mutableStateOf("") }
var deleteTotp by remember { mutableStateOf("") } var deleteTotp by remember { mutableStateOf("") }
@@ -346,18 +463,6 @@ fun SettingsScreen(
} }
} }
SettingsSection(title = "session") {
Spacer(Modifier.height(16.dp))
Button(
onClick = onLogout,
modifier = Modifier.fillMaxWidth(),
shape = RoundedCornerShape(8.dp),
colors = ButtonDefaults.buttonColors(containerColor = Color.White, contentColor = Color.Black)
) {
Text("logout")
}
}
if (settingsError != null) { if (settingsError != null) {
Text(settingsError!!, color = Color.Red, style = MaterialTheme.typography.bodySmall, modifier = Modifier.padding(top = 8.dp)) Text(settingsError!!, color = Color.Red, style = MaterialTheme.typography.bodySmall, modifier = Modifier.padding(top = 8.dp))
} }
@@ -466,6 +571,7 @@ fun SuccessButton(
} }
} }
@OptIn(ExperimentalTime::class)
@Composable @Composable
fun TwoFactorSetup( fun TwoFactorSetup(
qrCodeBase64: String?, qrCodeBase64: String?,
@@ -520,15 +626,13 @@ fun TwoFactorSetup(
Text(error.lowercase(), color = Color.Red, style = MaterialTheme.typography.labelSmall, modifier = Modifier.padding(top = 8.dp)) Text(error.lowercase(), color = Color.Red, style = MaterialTheme.typography.labelSmall, modifier = Modifier.padding(top = 8.dp))
} }
Spacer(Modifier.height(24.dp)) Spacer(Modifier.height(16.dp))
SuccessButton(
Button( onClick = { onConfirm(code) },
onClick = { if (code.length == 6) onConfirm(code) }, label = "verify and enable",
isSuccess = false, // Managed by parent
enabled = code.length == 6, enabled = code.length == 6,
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth()
shape = RoundedCornerShape(8.dp) )
) {
Text("confirm code")
}
} }
} }
@@ -2,12 +2,14 @@ package dev.zxq5.chatapp.android.ui.components
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedTextField import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.OutlinedTextFieldDefaults import androidx.compose.material3.OutlinedTextFieldDefaults
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.text.input.PasswordVisualTransformation import androidx.compose.ui.text.input.PasswordVisualTransformation
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
@@ -31,6 +33,11 @@ fun TextField(
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
singleLine = true, singleLine = true,
textStyle = MaterialTheme.typography.bodyLarge, textStyle = MaterialTheme.typography.bodyLarge,
keyboardOptions = if (isPassword) {
KeyboardOptions(keyboardType = KeyboardType.Password)
} else {
KeyboardOptions.Default
},
visualTransformation = if (isPassword) PasswordVisualTransformation() else androidx.compose.ui.text.input.VisualTransformation.None, visualTransformation = if (isPassword) PasswordVisualTransformation() else androidx.compose.ui.text.input.VisualTransformation.None,
shape = RoundedCornerShape(8.dp), shape = RoundedCornerShape(8.dp),
colors = OutlinedTextFieldDefaults.colors( colors = OutlinedTextFieldDefaults.colors(
@@ -40,6 +47,6 @@ fun TextField(
unfocusedBorderColor = MaterialTheme.colorScheme.outline, unfocusedBorderColor = MaterialTheme.colorScheme.outline,
focusedTextColor = MaterialTheme.colorScheme.onSurface, focusedTextColor = MaterialTheme.colorScheme.onSurface,
unfocusedTextColor = MaterialTheme.colorScheme.onSurface unfocusedTextColor = MaterialTheme.colorScheme.onSurface
) ),
) )
} }
@@ -2,7 +2,7 @@ package dev.zxq5.chatapp.android.ui.theme
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
val Black = Color(0xFF0A0A0A) val Black = Color(0xFF000000)
val DarkGrey = Color(0xFF0D0D0D) val DarkGrey = Color(0xFF0D0D0D)
val Grey = Color(0xFF141414) val Grey = Color(0xFF141414)
val LightGrey = Color(0xFF1E1E1E) val LightGrey = Color(0xFF1E1E1E)
@@ -0,0 +1,5 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android" android:autoMirrored="true" android:height="24dp" android:tint="#FFFFFF" android:viewportHeight="24" android:viewportWidth="24" android:width="24dp">
<path android:fillColor="@android:color/white" android:pathData="M20,2L4,2c-1.1,0 -1.99,0.9 -1.99,2L2,22l4,-4h14c1.1,0 2,-0.9 2,-2L22,4c0,-1.1 -0.9,-2 -2,-2zM18,14L6,14v-2h12v2zM18,11L6,11L6,9h12v2zM18,8L6,8L6,6h12v2z"/>
</vector>
+10
View File
@@ -0,0 +1,10 @@
# Default ignored files
/shelf/
/workspace.xml
# Ignored default folder with query files
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="uk.co.ben_gibson.git.link.SettingsState">
<option name="host" value="e0f86390-1091-4871-8aeb-f534fbc99cf0" />
</component>
</project>
+12
View File
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="EMPTY_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
+17
View File
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="chatapp_dev@100.118.108.58" uuid="b14acf5d-6750-469b-8aea-59c8343eb11c">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>
<jdbc-url>jdbc:postgresql://100.118.108.58:5432/chatapp_dev</jdbc-url>
<jdbc-additional-properties>
<property name="com.intellij.clouds.kubernetes.db.host.port" />
<property name="com.intellij.clouds.kubernetes.db.enabled" value="false" />
<property name="com.intellij.clouds.kubernetes.db.container.port" />
</jdbc-additional-properties>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>
+7
View File
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourcePerFileMappings">
<file url="file://$PROJECT_DIR$/sql/test.sql" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
<file url="file://$PROJECT_DIR$/src/repo/user_repo.rs" value="b14acf5d-6750-469b-8aea-59c8343eb11c" />
</component>
</project>
+6
View File
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
</profile>
</component>
+8
View File
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/backend.iml" filepath="$PROJECT_DIR$/.idea/backend.iml" />
</modules>
</component>
</project>
+8
View File
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/migrations/20260412200102_message_id_to_uuid.sql" dialect="PostgreSQL" />
<file url="file://$PROJECT_DIR$/sql/test.sql" dialect="PostgreSQL" />
<file url="PROJECT" dialect="PostgreSQL" />
</component>
</project>
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>
+7 -4
View File
@@ -10,9 +10,9 @@ dotenv = "0.15.0"
futures-util = "0.3.31" futures-util = "0.3.31"
image = "0.25.8" image = "0.25.8"
jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] } jsonwebtoken = { version = "10.3.0", features = ["rust_crypto"] }
rand = "0.9.2" rand = "0.8"
redis = { version = "0.25.4", features = ["tokio-comp"] } redis = { version = "0.25.4", features = ["tokio-comp"] }
reqwest = { version = "0.12.23", features = ["json"] } reqwest = { version = "0.12.23", features = ["json", "stream"] }
rocket = { version = "0.5.1", features = ["json", "secrets"] } rocket = { version = "0.5.1", features = ["json", "secrets"] }
rocket_cors = "0.6.0" rocket_cors = "0.6.0"
rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] } rocket_db_pools = { version = "0.2.0", features = ["deadpool_redis", "sqlx_macros", "sqlx_postgres"] }
@@ -20,8 +20,11 @@ rocket_dyn_templates = { version = "0.2.0", features = ["tera"] }
serde = { version = "1.0.228", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145" serde_json = "1.0.145"
sha2 = "0.10.9" sha2 = "0.10.9"
sqlx = { version = "0.7.4", features = ["macros", "time"] } sqlx = { version = "0.7.4", features = ["chrono", "macros", "postgres", "time", "uuid"] }
tokio = { version = "1.47.1", features = ["full"] } tokio = { version = "1.47.1", features = ["full"] }
totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] } totp-rs = { version = "5.7.0", features = ["gen_secret", "qr", "rand"] }
tracing = "0.1.44" tracing = "0.1.44"
uuid = { version = "1.18.1", features = ["v4"] } uuid = { version = "1.18.1", features = ["serde", "v4"] }
thiserror = "1.0.69"
utoipa = { version = "5.4.0", features = ["rocket_extras", "chrono"] }
clap = { version = "4.5", features = ["derive"] }
-4
View File
@@ -12,8 +12,6 @@ COPY cdn cdn
COPY src src COPY src src
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY Rocket.toml Rocket.toml COPY Rocket.toml Rocket.toml
COPY static static
COPY templates templates
RUN apt-get update && apt-get install -y libssl-dev pkg-config RUN apt-get update && apt-get install -y libssl-dev pkg-config
@@ -37,9 +35,7 @@ COPY --from=build /build/main ./
## copy runtime assets which may or may not exist ## copy runtime assets which may or may not exist
COPY --from=build /build/Rocket.toml ./Rocket.toml COPY --from=build /build/Rocket.toml ./Rocket.toml
COPY --from=build /build/static ./static
COPY --from=build /build/cdn ./cdn COPY --from=build /build/cdn ./cdn
COPY --from=build /build/template[s] ./templates
## ensure the container listens globally on port 8000 ## ensure the container listens globally on port 8000
ENV ROCKET_ADDRESS=0.0.0.0 ENV ROCKET_ADDRESS=0.0.0.0
+1 -1
View File
@@ -1,7 +1,7 @@
[debug] [debug]
secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU=" secret_key = "yYhvCGnRh/TrcHtB8sZqCFifrVmJxoKFLBYw/WWBZeU="
address = "0.0.0.0" address = "0.0.0.0"
port = 8000 port = 8080
[debug.databases.postgres_db] [debug.databases.postgres_db]
url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev" url = "postgresql://chatapp:chatapp@100.118.108.58:5432/chatapp_dev"
@@ -1,17 +1,18 @@
services: services:
backend: backend:
container_name: chatapp_backend
build: build:
context: ./backend context: .
args: args:
- DATABASE_URL=${DATABASE_URL} - DATABASE_URL=${DATABASE_URL}
ports: ports:
- "8000:8000" - "8000:8000"
depends_on: depends_on:
- redis - redis
environment: env_file:
- ROCKET_SECRET_KEY=${ROCKET_SECRET_KEY} - .env
- DATABASE_URL=${DATABASE_URL}
redis: redis:
container_name: chatapp_redis
image: docker.io/library/redis:alpine image: docker.io/library/redis:alpine
ports: ports:
- "6379:6379" - "6379:6379"
@@ -1,14 +1,16 @@
services: services:
backend: backend:
container_name: chatapp_backend container_name: chatapp_backend
image: git.zxq5.dev/zxq5/chatapp-backend:latest image: git.zxq5.dev/zxq5/chatapp-backend:v0.4.1
ports: ports:
- "8080:8000" - "8000:8000"
depends_on: depends_on:
- redis - redis
env_file:
- .env
environment: environment:
- ROCKET_SECRET_KEY=${ROCKET_SECRET_KEY} - RELEASE_MODE=1
- DATABASE_URL=${DATABASE_URL}
redis: redis:
container_name: chatapp_redis container_name: chatapp_redis
image: docker.io/library/redis:alpine image: docker.io/library/redis:alpine
@@ -1,49 +0,0 @@
-- Add migration script here
CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
password VARCHAR(50) NOT NULL,
display_name VARCHAR(50),
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE channels (
id SERIAL PRIMARY KEY,
name VARCHAR(50) NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE messages (
id SERIAL PRIMARY KEY,
channel_id INTEGER NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
is_edited BOOLEAN DEFAULT FALSE
);
create table attachments (
id SERIAL PRIMARY KEY,
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
path TEXT NOT NULL
);
CREATE INDEX idx_messages_channel_id ON messages (channel_id, id DESC);
CREATE INDEX idx_new_messages ON messages(created_at DESC);
-- Create a function to update the updated_at timestamp
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ language 'plpgsql';
-- Create trigger for messages table
CREATE TRIGGER update_messages_updated_at
BEFORE UPDATE ON messages
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
@@ -1,20 +0,0 @@
-- Add migration script here
CREATE TABLE sessions (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES users(id),
token TEXT NOT NULL UNIQUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '7 days'
);
CREATE OR REPLACE FUNCTION cleanup_expired_sessions()
RETURNS TRIGGER AS $$
BEGIN
DELETE FROM sessions WHERE expires_at < NOW();
RETURN NULL;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_cleanup_sessions
AFTER INSERT ON sessions
EXECUTE FUNCTION cleanup_expired_sessions();
@@ -1,5 +0,0 @@
-- Add migration script here
ALTER TABLE users ADD COLUMN email VARCHAR(100);
ALTER TABLE users ADD COLUMN twofa_enabled BOOLEAN DEFAULT FALSE;
ALTER TABLE users ADD COLUMN totp_secret VARCHAR(64);
ALTER TABLE users ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP;
@@ -1,2 +0,0 @@
-- Add migration script here
ALTER TABLE users ALTER COLUMN twofa_enabled SET NOT NULL;
@@ -1,18 +0,0 @@
-- Add migration script here
CREATE TABLE access_codes (
-- identifiers
id SERIAL PRIMARY KEY,
creator_id INTEGER NOT NULL REFERENCES users(id),
-- code data
code VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
-- uses
uses INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 1,
-- time data
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + INTERVAL '1 day'
);
@@ -1,10 +0,0 @@
-- Add migration script here
ALTER TABLE access_codes
ALTER COLUMN created_at
TYPE TIMESTAMP WITH TIME ZONE
USING created_at AT TIME ZONE 'UTC';
ALTER TABLE access_codes
ALTER COLUMN expires_at
TYPE TIMESTAMP WITH TIME ZONE
USING expires_at AT TIME ZONE 'UTC';
@@ -1,6 +0,0 @@
-- Add migration script here
TRUNCATE TABLE users CASCADE;
ALTER TABLE users DROP COLUMN password;
ALTER TABLE users ADD COLUMN pass_hash VARCHAR(255) NOT NULL;
ALTER TABLE users ADD CONSTRAINT users_username_unique UNIQUE (username);
@@ -1,13 +0,0 @@
-- Add migration script here
CREATE TYPE status AS ENUM ('pending', 'accepted', 'blocked');
CREATE TABLE relationships (
id SERIAL PRIMARY KEY,
from_user INTEGER NOT NULL REFERENCES users(id),
to_user INTEGER NOT NULL REFERENCES users(id),
status status NOT NULL DEFAULT 'pending',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
);
@@ -0,0 +1,183 @@
-- Add migration script here
-- Add migration script here
CREATE TYPE user_role AS ENUM ('user', 'admin');
CREATE TYPE totp_status AS ENUM ('disabled', 'pending', 'enabled');
CREATE TABLE users (
id BIGSERIAL PRIMARY KEY NOT NULL,
role user_role NOT NULL DEFAULT 'user',
-- profile
nickname VARCHAR(255),
-- basic auth
username VARCHAR(255) UNIQUE NOT NULL,
passhash VARCHAR(255) NOT NULL,
-- email
email VARCHAR(255),
email_verified BOOLEAN DEFAULT FALSE,
-- 2fa
totp_secret VARCHAR(255),
totp_status totp_status NOT NULL DEFAULT 'disabled',
-- update tracking
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ
);
CREATE TABLE access_tokens (
id BIGSERIAL PRIMARY KEY NOT NULL,
creator_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
code VARCHAR(255) NOT NULL,
name VARCHAR(255) NOT NULL,
uses INTEGER NOT NULL DEFAULT 0,
max_uses INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '24 hours',
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE refresh_tokens (
id BIGSERIAL PRIMARY KEY NOT NULL,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash VARCHAR(255) NOT NULL,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + INTERVAL '7 days'
);
CREATE TABLE spaces (
id BIGSERIAL PRIMARY KEY NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
owner_id BIGINT NOT NULL REFERENCES users(id),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE channels (
id BIGSERIAL PRIMARY KEY NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT,
space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE space_members (
space_id BIGINT NOT NULL REFERENCES spaces(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
role user_role DEFAULT 'user',
PRIMARY KEY (space_id, user_id)
);
CREATE TABLE messages (
id BIGSERIAL PRIMARY KEY NOT NULL,
channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
is_edited BOOLEAN DEFAULT FALSE
);
CREATE TABLE attachments (
id BIGSERIAL PRIMARY KEY NOT NULL,
message_id BIGINT NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
filename VARCHAR(255) NOT NULL,
content_type VARCHAR(100) NOT NULL, -- mime type e.g. image/png, video/mp4
size_bytes BIGINT NOT NULL,
url TEXT NOT NULL, -- path to file on your CDN/storage
width INTEGER, -- null for non-image/video
height INTEGER, -- null for non-image/video
duration_ms INTEGER, -- null for non-audio/video
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TYPE relationship_status AS ENUM ('pending', 'accepted', 'blocked');
CREATE TABLE relationships (
id BIGSERIAL PRIMARY KEY NOT NULL,
from_user BIGINT NOT NULL REFERENCES users(id),
to_user BIGINT NOT NULL REFERENCES users(id),
status relationship_status NOT NULL DEFAULT 'pending',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT no_self_relationship CHECK (from_user != to_user),
CONSTRAINT unique_relationship UNIQUE (from_user, to_user)
);
CREATE INDEX idx_messages_channel_id ON messages(channel_id, created_at DESC);
CREATE INDEX idx_messages_user_id ON messages(user_id);
CREATE INDEX idx_attachments_message ON attachments(message_id);
CREATE INDEX idx_channels_space_id ON channels(space_id);
CREATE INDEX idx_space_members_user ON space_members(user_id);
CREATE INDEX idx_refresh_tokens_hash ON refresh_tokens(token_hash);
CREATE INDEX idx_relationships_from ON relationships(from_user, to_user);
CREATE INDEX idx_relationships_to ON relationships(to_user);
CREATE INDEX idx_access_tokens_code ON access_tokens(code);
CREATE OR REPLACE FUNCTION update_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER users_updated_at
BEFORE UPDATE ON users
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER spaces_updated_at
BEFORE UPDATE ON spaces
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER channels_updated_at
BEFORE UPDATE ON channels
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER messages_updated_at
BEFORE UPDATE ON messages
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER relationships_updated_at
BEFORE UPDATE ON relationships
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE TRIGGER access_tokens_updated_at
BEFORE UPDATE ON access_tokens
FOR EACH ROW EXECUTE FUNCTION update_updated_at();
CREATE OR REPLACE FUNCTION add_owner_to_space()
RETURNS TRIGGER AS $$
BEGIN
INSERT INTO space_members (space_id, user_id, role, joined_at)
VALUES (NEW.id, NEW.owner_id, 'admin', NOW());
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER space_owner_becomes_member
AFTER INSERT ON spaces
FOR EACH ROW EXECUTE FUNCTION add_owner_to_space();
@@ -0,0 +1,9 @@
ALTER TABLE attachments DROP CONSTRAINT attachments_message_id_fkey;
ALTER TABLE messages ALTER COLUMN id DROP DEFAULT;
ALTER TABLE messages ALTER COLUMN id TYPE uuid USING gen_random_uuid();
ALTER TABLE messages ALTER COLUMN id SET DEFAULT gen_random_uuid();
ALTER TABLE attachments ALTER COLUMN message_id TYPE uuid USING gen_random_uuid();
ALTER TABLE attachments ADD CONSTRAINT attachments_message_id_fkey
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE;
+17
View File
@@ -0,0 +1,17 @@
WITH space1 AS (
INSERT INTO spaces (name, description, owner_id)
VALUES ('general', 'Boring chat idk', 1)
RETURNING id
),
space2 AS (
INSERT INTO spaces (name, description, owner_id)
VALUES ('Gaming', 'we lose games', 1)
RETURNING id
)
INSERT INTO channels (name, description, space_id)
SELECT 'General', 'General chat', id FROM space1 UNION ALL
SELECT 'Coding', 'Coding stuff', id FROM space1 UNION ALL
SELECT 'AI', '"/ask" here pls :)', id FROM space1 UNION ALL
SELECT 'The Game', '(You lost)', id FROM space2 UNION ALL
SELECT 'Backrooms', 'Beware of Smilers', id FROM space2 UNION ALL
SELECT 'SE', 'Space/Software engineering.', id FROM space2;
+193
View File
@@ -0,0 +1,193 @@
use crate::error::ApiResult;
use crate::model::auth::{AccessTokenForm, AuthResponse, LoginCredentials, SignupCredentials};
use crate::svc::access_token_svc::AccessTokenService;
use crate::svc::auth_svc::AuthService;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use rocket::http::Status;
use rocket::request::{FromRequest, Outcome};
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::{Request, State};
use std::sync::LazyLock;
use std::time::{SystemTime, UNIX_EPOCH};
#[post("/signup", data = "<cred>")]
pub async fn signup(
cred: Json<SignupCredentials>,
svc: &State<AuthService>,
) -> ApiResult<Json<AuthResponse>> {
let response = svc
.signup(
&cred.email,
&cred.username,
&cred.password,
&cred.access_token,
)
.await?;
Ok(Json(response))
}
#[post("/login", data = "<cred>")]
pub async fn login(
cred: Json<LoginCredentials>,
svc: &State<AuthService>,
) -> ApiResult<Json<AuthResponse>> {
Ok(Json(svc.login(&cred.username, &cred.password).await?))
}
#[post("/invite", data = "<form>")]
pub async fn generate_invite(
session: AdminSession,
form: Json<AccessTokenForm>,
svc: &State<AccessTokenService>,
) -> ApiResult<String> {
svc.create(
session.uid,
&form.name,
form.max_uses,
form.start_date,
form.expiry_date,
)
.await
}
static JWT_SECRET: LazyLock<String> = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap());
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum TokenScope {
Full,
TotpPending,
}
pub struct Session {
pub uid: i64,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Session {
type Error = ();
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match Claims::from_request(req).await {
Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session {
uid: user.sub as i64,
}),
Outcome::Success(_) => {
eprintln!("warning: user with scope other than Full attempted to access session");
Outcome::Error((Status::Forbidden, ()))
}
Outcome::Error(err) => {
eprintln!("Session request guard failed: {:?}", err);
Outcome::Error(err)
}
_ => unreachable!("forward should never be called"),
}
}
}
pub struct AdminSession {
pub uid: i64,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for AdminSession {
type Error = ();
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
// First verify the session is valid
match Claims::from_request(req).await {
Outcome::Success(user) if user.scope == TokenScope::Full => {
let uid = user.sub as i64;
// Get AuthService from Rocket state
let auth_svc = match req.guard::<&State<AuthService>>().await {
Outcome::Success(svc) => svc,
Outcome::Error(err) => {
tracing::error!("AdminSession: Failed to get AuthService from state");
return Outcome::Error(err);
}
_ => unreachable!("forward should never be called"),
};
// Check if user is admin
match auth_svc.is_admin(uid).await {
Ok(true) => Outcome::Success(AdminSession { uid }),
Ok(false) => {
tracing::debug!("non-admin user attempted to access admin session");
Outcome::Error((Status::Forbidden, ()))
}
Err(err) => {
tracing::error!("AdminSession: is_admin check failed: {:?}", err);
Outcome::Error((Status::InternalServerError, ()))
}
}
}
Outcome::Success(_) => {
tracing::debug!("warning: user with scope other than Full attempted to access admin session");
Outcome::Error((Status::Forbidden, ()))
}
Outcome::Error(err) => {
tracing::debug!("AdminSession request guard failed: {:?}", err);
Outcome::Error(err)
}
_ => unreachable!("forward should never be called"),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct Claims {
pub sub: i32,
pub exp: usize,
pub scope: TokenScope,
}
impl Claims {
pub fn new(user_id: usize, scope: TokenScope) -> Self {
Self {
sub: user_id as i32,
exp: (SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Failed to get time")
.as_secs()
+ 60 * 60 * 24 * 7) as usize,
scope,
}
}
pub fn encode(&self) -> String {
encode(
&Header::default(),
self,
&EncodingKey::from_secret(JWT_SECRET.as_bytes()),
)
.expect("unable to encode jwt")
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Claims {
type Error = ();
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let token = req
.headers()
.get_one("Authorization")
.and_then(|v| v.strip_prefix("Bearer "));
match token {
None => Outcome::Error((Status::Unauthorized, ())),
Some(t) => {
match decode::<Claims>(
t,
&DecodingKey::from_secret(JWT_SECRET.as_bytes()),
&Validation::default(),
) {
Ok(data) => Outcome::Success(data.claims),
Err(_) => Outcome::Error((Status::Unauthorized, ())),
}
}
}
}
}
@@ -26,7 +26,7 @@ pub async fn profile_pic(user_id: usize) -> Option<NamedFile> {
Some(image) Some(image)
} else { } else {
Some( Some(
NamedFile::open("./cdn/profiles/full/default.svg") NamedFile::open("../../cdn/profiles/full/default.svg")
.await .await
.ok()?, .ok()?,
) )
+63
View File
@@ -0,0 +1,63 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::chat_svc::ChatService;
use chrono::{DateTime, Utc};
use rocket::response::stream::Event;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::{Shutdown, State, ___internal_EventStream as EventStream};
use sqlx::FromRow;
use tokio::select;
use tokio::sync::broadcast;
use crate::model::event::{ChatEvent, ChatMsg};
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message(
msg: Json<ChatMsg>,
chat: &State<ChatService>,
session: Session,
channel_id: i64,
) -> ApiResult<()> {
chat.send(channel_id, session.uid, &msg.text, Utc::now()).await
}
#[get("/events/<channel_id>")]
pub async fn event_stream(
chat: &State<ChatService>,
s: Session,
mut shutdown: Shutdown,
channel_id: i64,
) -> ApiResult<EventStream![]> {
let messages = chat.fetch_latest_messages_desc(channel_id, 100)
.await?; // if get message returned err, inform user.
let mut rx = chat.subscribe(channel_id).await;
let id = s.uid;
Ok(EventStream! {
for msg in messages.into_iter().rev() {
// tracing::info!("sending: {:?}", serde_json::to_string(&ChatEvent::SendMessage(msg.clone())).unwrap());
yield Event::json(&ChatEvent::SendMessage(msg));
}
loop {
select!{
_ = &mut shutdown => break, // exit early on shutdown
event = rx.recv() => match event {
Ok(event) => {
// tracing::info!("yielding event: {event:?}");
yield Event::json(&event)
},
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Receiver lagging on channel {channel_id} by {n} events",);
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => {
tracing::info!("Broadcaster hung up on channel {channel_id}!");
break
},
},
}
}
})
}
+7
View File
@@ -0,0 +1,7 @@
pub mod auth;
pub mod chat;
pub mod totp;
pub mod settings;
pub mod cdn;
pub mod profile;
pub mod space;
+13
View File
@@ -0,0 +1,13 @@
use rocket::State;
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::user_svc::UserService;
#[get("/users/<id>")]
pub async fn display_name(
id: i64,
_ag: Session,
svc: &State<UserService>,
) -> ApiResult<String> {
svc.get_username(id).await
}
+68
View File
@@ -0,0 +1,68 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::settings_svc::SettingsService;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::State;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordForm {
pub old_password: String,
pub new_password: String,
}
#[post("/settings/password", data = "<form>")]
pub async fn change_password(
session: Session,
form: Json<PasswordForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_password(
session.uid, &form.old_password, &form.new_password
).await
}
#[derive(Deserialize, Debug, Clone)]
pub struct DisplayNameForm {
pub display_name: Option<String>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct PasswordAnd2faForm {
pub password: String,
pub totp_code: Option<String>,
}
#[delete("/settings", data = "<data>")]
pub async fn delete_account(
session: Session,
data: Json<PasswordAnd2faForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.delete_account(
session.uid, &data.password, &data.totp_code
).await
}
#[patch("/settings/display_name", data = "<new>")]
pub async fn change_display_name(
session: Session,
new: Json<DisplayNameForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_display_name(session.uid, new.display_name.clone()).await
}
#[derive(Deserialize)]
pub struct UsernameForm {
pub username: String,
}
#[patch("/settings/username", data = "<new>")]
pub async fn change_username(
session: Session,
new: Json<UsernameForm>,
settings: &State<SettingsService>
) -> ApiResult<()> {
settings.change_username(session.uid, &new.username).await
}
+33
View File
@@ -0,0 +1,33 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::model::space::Channel;
use crate::model::space::{Space, SpaceDto};
use crate::repo::{ChannelRepo, SpaceRepo};
use crate::svc::chat_svc::ChatService;
use rocket::State;
use rocket::serde::json::Json;
use std::sync::Arc;
#[get("/spaces")]
pub async fn list_spaces(space_repo: &State<Arc<dyn SpaceRepo>>) -> ApiResult<Json<Vec<Space>>> {
let spaces = space_repo.get_all().await?;
Ok(Json(spaces))
}
#[get("/spaces/<space_id>/channels")]
pub async fn list_channels(
space_id: i64,
channel_repo: &State<Arc<dyn ChannelRepo>>,
) -> ApiResult<Json<Vec<Channel>>> {
let channels = channel_repo.get_by_space_id(space_id).await?;
Ok(Json(channels))
}
#[get("/accessible_channels")]
pub async fn get_accessible_channels(
session: Session,
svc: &State<ChatService>,
) -> ApiResult<Json<Vec<SpaceDto>>> {
let space = svc.get_accessible_channels(session.uid).await?;
Ok(Json(space))
}
+120
View File
@@ -0,0 +1,120 @@
use crate::api::auth::{Claims, Session, TokenScope};
use crate::error::{ApiResult, AppError};
use crate::model::auth::AuthResponse;
use crate::svc::auth_svc::AuthService;
use rocket::serde::json::Json;
use rocket::serde::{Deserialize, Serialize};
use rocket::State;
use totp_rs::{Algorithm, TOTP};
#[derive(Debug, Deserialize)]
pub struct TOTPSixDigitCode {
code: String,
}
#[derive(Debug, sqlx::Type, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[sqlx(type_name = "totp_status", rename_all = "lowercase")]
pub enum TotpStatus {
Enabled,
Pending,
Disabled,
}
#[derive(Serialize)]
pub struct QrResponse {
qr_code: String,
}
#[derive(Deserialize)]
pub struct TotpVerifyRequest {
pub code: String,
}
#[derive(Deserialize)]
pub struct PasswordConfirmation {
password: String,
}
#[derive(Deserialize)]
pub struct PasswordAnd2fa {
pub password: String,
pub totp_code: String,
}
pub fn totp_gen(user_id: i64, secret: &[u8]) -> ApiResult<TOTP> {
TOTP::new(
Algorithm::SHA1,
6,
1,
30,
secret.to_owned(),
Some("chat.zxq5.dev".to_string()),
format!("{}", user_id),
)
.map_err(|_| AppError::internal("failed to generate totp"))
}
#[post("/totp", data = "<form>")]
pub async fn confirm_totp(
user: Session,
form: Json<TOTPSixDigitCode>,
svc: &State<AuthService>,
) -> ApiResult<()> {
svc.confirm_totp(user.uid, &form.code).await
}
#[post("/totp.jpg", data = "<form>")]
pub async fn get_totp(
user: Session,
form: Json<PasswordConfirmation>,
svc: &State<AuthService>,
) -> ApiResult<Json<QrResponse>> {
let secret = svc.get_or_create_totp_secret(user.uid, &form.password).await?;
let qr_b64 = totp_gen(user.uid, secret.as_bytes())
.map_err(|_| AppError::internal("invalid totp secret"))?
.get_qr_base64()
.map_err(|_| AppError::internal("failed to generate qr code"))?;
Ok(Json(QrResponse {
qr_code: format!("data:image/png;base64,{}", qr_b64),
}))
}
#[get("/totp/status")]
pub async fn get_totp_status(
user: Session,
svc: &State<AuthService>,
) -> ApiResult<Json<TotpStatus>> {
Ok(Json(
svc.get_totp_status(user.uid).await?
.then_some(TotpStatus::Enabled)
.unwrap_or(TotpStatus::Disabled),
))
}
#[delete("/totp", data = "<form>")]
pub async fn disable_totp(
user: Session,
form: Json<PasswordAnd2fa>,
svc: &State<AuthService>,
) -> ApiResult<Json<AuthResponse>> {
let response = svc.disable_totp(user.uid, &form.password, &form.totp_code).await?;
Ok(Json(response))
}
#[post("/totp/verify", data = "<body>")]
pub async fn verify_totp(
claims: Claims,
body: Json<TotpVerifyRequest>,
svc: &State<AuthService>,
) -> ApiResult<Json<AuthResponse>> {
// reject if they somehow got here with a full token
if claims.scope != TokenScope::TotpPending {
return Err(AppError::Forbidden);
}
let response = svc.login_totp(claims.sub as i64, &body.code).await?;
Ok(Json(response))
}
-211
View File
@@ -1,211 +0,0 @@
use argon2::{
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
password_hash::{SaltString, rand_core::OsRng},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use rocket::{
http::{CookieJar, Status},
response::{Redirect, status::BadRequest},
serde::json::Json,
time::OffsetDateTime,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{
auth::session::{Claims, Session, TokenScope},
db::Postgres,
user::User,
};
#[derive(Serialize, Deserialize)]
pub struct SignupCredentials {
pub email: String,
pub username: String,
pub password: String,
pub access_token: String,
}
#[derive(Serialize, Deserialize)]
pub struct LoginCredentials {
pub username: String,
pub password: String,
}
#[derive(Serialize, Deserialize)]
pub struct AuthResponse {
pub token: String,
pub totp_required: bool,
}
#[get("/signup")]
pub async fn signup_page() -> Template {
Template::render("signup", context!())
}
#[post("/signup", data = "<cred>")]
pub async fn signup(
cred: Json<SignupCredentials>,
jar: &CookieJar<'_>,
mut db: Connection<Postgres>,
) -> Result<Json<AuthResponse>, Status> {
let token_id = AccessToken::validate(&cred.access_token, &mut db)
.await
.map_err(|_| Status::Unauthorized)?;
let salt = SaltString::generate(&mut OsRng);
let hashed = Argon2::default()
.hash_password(cred.password.as_bytes(), &salt)
.map_err(|_| Status::InternalServerError)?
.to_string();
let result = sqlx::query!(
"INSERT INTO users (email, username, pass_hash) VALUES ($1, $2, $3) RETURNING id",
cred.email,
cred.username,
hashed,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::InternalServerError)?;
let jwt = Claims::new(result.id as usize, TokenScope::Full).encode();
token_id
.use_token(&mut db)
.await
.expect("unable to use access code");
Ok(Json(AuthResponse {
token: jwt,
totp_required: false,
}))
}
#[get("/login")]
pub async fn login_page() -> Template {
Template::render("login", context!())
}
#[post("/login", data = "<cred>")]
pub async fn login(
mut db: Connection<Postgres>,
cred: Json<LoginCredentials>,
) -> Result<Json<AuthResponse>, Status> {
println!("e");
let row = sqlx::query!(
"SELECT id, pass_hash, twofa_enabled FROM users WHERE username = $1",
cred.username,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::Unauthorized)?;
println!("ok");
// verify password as before
let parsed_hash = PasswordHash::new(&row.pass_hash).map_err(|_| Status::InternalServerError)?;
Argon2::default()
.verify_password(cred.password.as_bytes(), &parsed_hash)
.map_err(|_| Status::Unauthorized)?;
println!("ok2");
let user_id = row.id as usize;
// issue either a partial or full token depending on 2FA status
let (session, totp_required) = if row.twofa_enabled {
(Claims::new(user_id, TokenScope::TotpPending), true)
} else {
(Claims::new(user_id, TokenScope::Full), false)
};
Ok(Json(AuthResponse {
token: session.encode(),
totp_required,
}))
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenForm {
pub name: String,
pub max_uses: usize,
pub expiry_date: usize,
pub start_date: usize,
}
#[get("/invite")]
pub async fn invite_page(_s: Session) -> Template {
Template::render("invite", context! {})
}
#[post("/invite", data = "<form>")]
pub async fn generate_invite(
session: Session,
mut db: Connection<Postgres>,
form: Json<AccessTokenForm>,
) -> Result<String, Status> {
if form.start_date > form.expiry_date {
return Err(Status::BadRequest);
}
let code = Uuid::new_v4().to_string();
sqlx::query!(
"INSERT INTO access_codes (name, code, creator_id, max_uses, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
form.name,
code,
session.user_id as i32,
form.max_uses as i32,
OffsetDateTime::from_unix_timestamp_nanos(form.start_date as i128 * 1_000_000).unwrap(),
OffsetDateTime::from_unix_timestamp_nanos(form.expiry_date as i128 * 1_000_000).unwrap()
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::InternalServerError)?;
Ok(code)
}
pub struct AccessToken {
id: i32,
_code: String,
}
impl AccessToken {
pub async fn validate(
token: &str,
db: &mut Connection<Postgres>,
) -> Result<AccessToken, String> {
match sqlx::query!(
"SELECT id FROM access_codes
WHERE code = $1
AND created_at < NOW()
AND expires_at > NOW()
AND uses < max_uses",
token
)
.fetch_one(&mut ***db)
.await
{
Ok(row) => Ok(AccessToken {
id: row.id,
_code: token.to_string(),
}),
Err(_) => Err(String::from("Invalid or Expired token!")),
}
}
pub async fn use_token(&self, db: &mut Connection<Postgres>) -> Result<(), String> {
sqlx::query!(
"UPDATE access_codes SET uses = uses + 1 WHERE id = $1",
self.id
)
.execute(&mut ***db)
.await
.map_err(|_| String::from("Invalid or Expired token!"))?;
Ok(())
}
}
-12
View File
@@ -1,12 +0,0 @@
pub mod account;
pub mod profile;
pub mod session;
pub mod two_factor;
pub use session::Session;
pub use account::{generate_invite, invite_page, login, login_page, signup, signup_page};
pub use profile::{change_display_name, change_password, change_username, delete_account};
pub use two_factor::{
confirm_totp, disable_totp, get_totp, get_totp_status, mfa_page, verify_totp,
};
-143
View File
@@ -1,143 +0,0 @@
use argon2::{
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
password_hash::{SaltString, rand_core::OsRng},
};
use rocket::{http::Status, serde::json::Json};
use rocket_db_pools::Connection;
use serde::{Deserialize, Serialize};
use crate::{auth::Session, db::Postgres, user::User};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordForm {
old_password: String,
new_password: String,
}
#[post("/settings/password", data = "<form>")]
pub async fn change_password(
session: Session,
mut db: Connection<Postgres>,
form: Json<PasswordForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.verify_password(&form.old_password)?;
// old password is correct, so new one can be set.
let salt = SaltString::generate(&mut OsRng);
let hashed = Argon2::default()
.hash_password(form.new_password.as_bytes(), &salt)
.inspect_err(|e| tracing::error!("failed to hash password! {e}"))
.map_err(|_| Status::InternalServerError)?
.to_string();
user.set_pass_hash(hashed, &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[derive(Deserialize, Debug, Clone)]
pub struct DisplayNameForm {
pub display_name: Option<String>,
}
#[derive(Deserialize, Debug, Clone)]
pub struct PasswordAnd2fa {
pub password: String,
pub totp_code: Option<String>,
}
#[delete("/settings", data = "<data>")]
pub async fn delete_account(
session: Session,
mut db: Connection<Postgres>,
data: Json<PasswordAnd2fa>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.verify_password(&data.password)?;
if user.twofa_enabled {
user.verify_2fa(data.totp_code.as_deref().unwrap_or(""))?;
}
user.delete(&mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[patch("/settings/display_name", data = "<new>")]
pub async fn change_display_name(
session: Session,
mut db: Connection<Postgres>,
new: Json<DisplayNameForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.set_display_name(new.display_name.clone(), &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
#[derive(Deserialize)]
pub struct UsernameForm {
username: String,
}
#[patch("/settings/username", data = "<new>")]
pub async fn change_username(
session: Session,
mut db: Connection<Postgres>,
new: Json<UsernameForm>,
) -> Result<(), Status> {
let mut user = User::get_by_id(session.user_id, &mut db)
.await
.ok_or(Status::NotFound)
.inspect_err(|_| {
tracing::error!(
"Valid session does not have a valid user. ID: {}",
session.user_id
)
})?;
user.set_username(new.username.clone(), &mut db)
.await
.inspect_err(|e| tracing::error!("{e}"))
.map_err(|_| Status::InternalServerError)?;
Ok(())
}
-109
View File
@@ -1,109 +0,0 @@
use std::{
sync::LazyLock,
time::{SystemTime, UNIX_EPOCH},
};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use rand::Rng;
use rocket::{
Request,
http::Status,
request::{self, FromRequest, Outcome},
};
use rocket_db_pools::Connection;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256, digest::block_buffer::Lazy};
use sqlx::postgres::PgQueryResult;
use crate::db::Postgres;
static JWT_SECRET: LazyLock<String> = LazyLock::new(|| std::env::var("JWT_SECRET").unwrap());
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum TokenScope {
Full,
TotpPending,
}
pub struct Session {
pub user_id: usize,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Session {
type Error = ();
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match Claims::from_request(req).await {
Outcome::Success(user) if user.scope == TokenScope::Full => Outcome::Success(Session {
user_id: user.sub as usize,
}),
Outcome::Success(_) => {
eprintln!("warning: user with scope other than Full attempted to access session");
Outcome::Error((Status::Forbidden, ()))
}
Outcome::Error(err) => {
eprintln!("Session request guard failed: {:?}", err);
Outcome::Error(err)
}
_ => unreachable!("forward should never be called"),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct Claims {
pub sub: i32,
pub exp: usize,
pub scope: TokenScope,
}
impl Claims {
pub fn new(user_id: usize, scope: TokenScope) -> Self {
Self {
sub: user_id as i32,
exp: (SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600) as usize,
scope,
}
}
pub fn encode(&self) -> String {
encode(
&Header::default(),
self,
&EncodingKey::from_secret(JWT_SECRET.as_bytes()),
)
.expect("unable to encode jwt")
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Claims {
type Error = ();
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let token = req
.headers()
.get_one("Authorization")
.and_then(|v| v.strip_prefix("Bearer "));
match token {
None => Outcome::Error((Status::Unauthorized, ())),
Some(t) => {
match decode::<Claims>(
t,
&DecodingKey::from_secret(JWT_SECRET.as_bytes()),
&Validation::default(),
) {
Ok(data) => Outcome::Success(data.claims),
Err(_) => Outcome::Error((Status::Unauthorized, ())),
}
}
}
}
}
-301
View File
@@ -1,301 +0,0 @@
use futures_util::TryFutureExt;
use rocket::{
Request,
http::Status,
outcome::{Outcome, try_outcome},
request::{self, FromRequest},
response::status::{self},
serde::json::Json,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use totp_rs::{Algorithm, Secret, TOTP};
use crate::{
auth::{
account::AuthResponse,
profile::PasswordAnd2fa,
session::{Claims, Session, TokenScope},
},
db::Postgres,
user::User,
};
// Utility methods
pub fn totp_gen(user_id: usize, secret: &[u8]) -> Result<TOTP, String> {
TOTP::new(
Algorithm::SHA1,
6,
1,
30,
secret.to_owned(),
Some("chat.zxq5.dev".to_string()),
format!("{}", user_id),
)
.map_err(|_| String::from("Invalid Secret"))
}
// pages
#[get("/totp")]
pub async fn mfa_page(_session: Session) -> Template {
Template::render("2fa", context!())
}
#[post("/totp", data = "<form>")]
pub async fn confirm_totp(
mfa: TOTPSecret,
form: Json<TOTPSixDigitCode>,
mut db: Connection<Postgres>,
) -> Result<(), status::Custom<&'static str>> {
if form.code.len() != 6 || form.code.parse::<u32>().is_err() {
return Err(status::Custom(Status::BadRequest, "Invalid 6-digit code"));
}
println!("valid");
let totp = totp_gen(mfa.user_id, mfa.secret.as_bytes())
.map_err(|_| status::Custom(Status::InternalServerError, "TOTP Error"))?;
if !totp.check_current(&form.code).unwrap_or(false) {
return Err(status::Custom(Status::BadRequest, "Incorrect code"));
}
println!("correct");
if sqlx::query!(
"UPDATE users SET twofa_enabled = true WHERE id = $1",
mfa.user_id as i32
)
.execute(&mut **db)
.await
.is_err()
{
return Err(status::Custom(
Status::InternalServerError,
"unable to enable 2fa",
));
};
println!("enabled");
Ok(())
}
#[derive(Deserialize)]
pub struct PasswordConfirmation {
password: String,
}
#[post("/totp.jpg", data = "<form>")]
pub async fn get_totp(
mfa: TOTPSecret,
form: Json<PasswordConfirmation>,
) -> Option<Json<QrResponse>> {
let qr_b64 = totp_gen(mfa.user_id, mfa.secret.as_bytes())
.expect("Invalid TOTP")
.get_qr_base64()
.unwrap();
Some(Json(QrResponse {
qr_code: format!("data:image/png;base64,{}", qr_b64),
}))
}
#[derive(Debug, Deserialize)]
pub struct TOTPSixDigitCode {
code: String,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TotpStatus {
Enabled,
Disabled,
}
pub struct TOTPSecret {
user_id: usize,
secret: String,
}
#[derive(Serialize)]
pub struct QrResponse {
qr_code: String,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for TOTPSecret {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
let auth_header = request.headers().get_one("Authorization");
println!(
"TOTPSecret guard - Auth header present: {}",
auth_header.is_some()
);
let user = try_outcome!(request.guard::<Claims>().await);
println!(
"TOTPSecret guard - Claims ok, user: {}, scope: {:?}",
user.sub, user.scope
);
// only allow full tokens for TOTP setup
if user.scope != TokenScope::Full {
println!("TOTPSecret guard - rejected, scope is {:?}", user.scope);
return Outcome::Error((Status::Forbidden, ()));
}
let user = try_outcome!(request.guard::<Session>().await);
let mut pool = match request.guard::<Connection<Postgres>>().await {
Outcome::Success(pool) => pool,
_ => return Outcome::Error((Status::Unauthorized, ())),
};
let row = sqlx::query!(
"SELECT twofa_enabled, totp_secret FROM users WHERE id = $1",
user.user_id as i32
)
.fetch_one(&mut **pool)
.await;
let (enabled, mut secret) = match row {
Ok(r) => (r.twofa_enabled, r.totp_secret),
Err(_) => return Outcome::Error((Status::Unauthorized, ())),
};
if secret.is_none() {
let new_secret = Secret::generate_secret().to_encoded().to_string();
sqlx::query!(
"UPDATE users SET totp_secret = $1 WHERE id = $2",
new_secret,
user.user_id as i32
)
.execute(&mut **pool)
.await
.ok();
secret = Some(new_secret);
}
Outcome::Success(TOTPSecret {
user_id: user.user_id,
secret: secret.unwrap(),
})
}
}
impl TOTPSecret {
pub async fn enable(&self, db: &mut Connection<Postgres>) -> Result<(), ()> {
match sqlx::query!(
"UPDATE users SET twofa_enabled = true WHERE id = $1",
self.user_id as i32,
)
.execute(&mut ***db)
.await
{
Ok(_) => Ok(()),
Err(_) => Err(()),
}
}
}
#[derive(Deserialize)]
pub struct TotpVerifyRequest {
pub code: String,
}
#[get("/totp/status")]
pub async fn get_totp_status(
user: Session,
mut db: Connection<Postgres>,
) -> Result<Json<TotpStatus>, Status> {
Ok(Json(
if sqlx::query!(
"SELECT twofa_enabled FROM users WHERE id = $1",
user.user_id as i32,
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::NotFound)?
.twofa_enabled
{
TotpStatus::Enabled
} else {
TotpStatus::Disabled
},
))
}
#[delete("/totp", data = "<form>")]
pub async fn disable_totp(
user: Session,
mut db: Connection<Postgres>,
form: Json<PasswordAnd2fa>,
) -> Result<Json<AuthResponse>, Status> {
let totp_code = form.totp_code.clone().ok_or(Status::BadRequest)?;
let mut user = User::get_by_id(user.user_id, &mut db)
.await
.ok_or(Status::NotFound)?;
user.verify_password(&form.password)?;
user.verify_2fa(&totp_code)?;
user.set_twofa_enabled(false, &mut db)
.await
.map_err(|_| Status::InternalServerError)?;
Ok(Json(AuthResponse {
token: Claims::new(user.id as usize, TokenScope::Full).encode(),
totp_required: false,
}))
}
#[post("/totp/verify", data = "<body>")]
pub async fn verify_totp(
claims: Claims, // request guard checks token validity
mut db: Connection<Postgres>,
body: Json<TotpVerifyRequest>,
) -> Result<Json<AuthResponse>, Status> {
println!("reached 1");
// reject if they somehow got here with a full token
if claims.scope != TokenScope::TotpPending {
return Err(Status::Forbidden);
}
println!("reached 2");
let row = sqlx::query!(
"SELECT totp_secret FROM users WHERE id = $1 AND twofa_enabled = TRUE",
claims.sub
)
.fetch_one(&mut **db)
.await
.map_err(|_| Status::Unauthorized)?;
println!("reached 3");
let totp = totp_gen(
claims.sub as usize,
row.totp_secret
.expect("user with 2fa enabled has no totp secret")
.as_bytes(),
)
.map_err(|_| Status::InternalServerError)?;
if !totp
.check_current(&body.code)
.map_err(|_| Status::InternalServerError)?
{
return Err(Status::Unauthorized);
}
println!("reached 5");
let claims = Claims::new(claims.sub as usize, TokenScope::Full);
Ok(Json(AuthResponse {
token: claims.encode(),
totp_required: false,
}))
}
+96
View File
@@ -0,0 +1,96 @@
use clap::{Parser, Subcommand};
use sqlx::postgres::PgPoolOptions;
use std::time::Duration;
use std::sync::Arc;
use crate::repo::user_repo::UserRepository;
use crate::repo::space_repo::SpaceRepository;
use crate::repo::channel_repo::ChannelRepository;
use crate::repo::{UserRepo, SpaceRepo, ChannelRepo};
use argon2::{
password_hash::{PasswordHasher, SaltString},
Argon2,
};
use rand::rngs::OsRng;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
pub struct Cli {
#[command(subcommand)]
pub command: Option<Commands>,
}
#[derive(Subcommand)]
pub enum Commands {
/// First-time setup for the server
Setup {
/// Admin username
#[arg(short, long)]
username: String,
/// Admin password
#[arg(short, long)]
password: String,
/// Default space name
#[arg(short, long, default_value = "Default Space")]
space: String,
/// Default channel name
#[arg(short, long, default_value = "general")]
channel: String,
},
}
pub async fn handle_cli() -> bool {
let cli = Cli::parse();
match cli.command {
Some(Commands::Setup { username, password, space, channel }) => {
if let Err(e) = run_setup(username, password, space, channel).await {
eprintln!("Setup failed: {}", e);
std::process::exit(1);
}
println!("Setup completed successfully!");
true
}
None => false,
}
}
async fn run_setup(username: String, password: String, space_name: String, channel_name: String) -> Result<(), Box<dyn std::error::Error>> {
dotenv::dotenv().ok();
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let pool = PgPoolOptions::new()
.max_connections(1)
.acquire_timeout(Duration::from_secs(5))
.connect(&db_url)
.await?;
let user_repo = UserRepository::new(pool.clone());
let space_repo = SpaceRepository::new(pool.clone());
let channel_repo = ChannelRepository::new(pool.clone());
// 1. Create admin user
println!("Creating admin user: {}...", username);
let argon2 = Argon2::default();
let salt = SaltString::generate(&mut OsRng);
let passhash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| e.to_string())?
.to_string();
let user_id = user_repo.new_user("admin@localhost", &username, &passhash).await?;
user_repo.set_role(user_id, "admin").await?;
// 2. Create default space
println!("Creating default space: {}...", space_name);
let space_id = space_repo.create(&space_name, Some("Default space created during setup"), user_id).await?;
// 3. Create default channel
println!("Creating default channel: {}...", channel_name);
channel_repo.create(&channel_name, Some("Default channel"), space_id).await?;
Ok(())
}
-9
View File
@@ -1,9 +0,0 @@
use rocket_db_pools::{Database, deadpool_redis};
#[derive(Database)]
#[database("postgres_db")]
pub struct Postgres(sqlx::PgPool);
#[derive(Database)]
#[database("redis_cache")]
pub struct Redis(deadpool_redis::Pool);
+127
View File
@@ -0,0 +1,127 @@
// error.rs
use rocket::{http::Status, response::{self, Responder}, Request, Response};
use thiserror::Error;
use rocket_dyn_templates::Template;
use rocket::serde::Serialize;
#[derive(Error, Debug)]
pub enum AppError {
#[error("Not found")]
NotFound,
#[error("Unauthorized")]
Unauthorised(String),
#[error("Forbidden")]
Forbidden,
#[error("Bad request: {0}")]
BadRequest(String),
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Internal error: {0}")]
Internal(String),
}
impl AppError {
pub fn internal(msg: impl Into<String>) -> Self {
Self::Internal(msg.into())
}
pub fn bad_request(msg: impl Into<String>) -> Self {
Self::BadRequest(msg.into())
}
pub fn unauthorised(msg: impl Into<String>) -> Self {
Self::Unauthorised(msg.into())
}
}
impl<'r> Responder<'r, 'static> for AppError {
fn respond_to(self, _req: &'r Request<'_>) -> response::Result<'static> {
let status = match &self {
AppError::NotFound => Status::NotFound,
AppError::Unauthorised(_) => Status::Unauthorized,
AppError::Forbidden => Status::Forbidden,
AppError::BadRequest(_) => Status::BadRequest,
AppError::Database(_) => Status::InternalServerError,
AppError::Internal(_) => Status::InternalServerError,
};
// log internal errors
if status == Status::InternalServerError {
tracing::error!("Internal Server Error: {}", self);
}
Response::build()
.status(status)
.header(rocket::http::ContentType::Plain)
.sized_body(
self.to_string().len(),
std::io::Cursor::new(self.to_string())
)
.ok()
}
}
pub type ApiResult<T> = Result<T, AppError>;
#[derive(Serialize)]
struct ErrorContext {
error_code: u16,
error_message: &'static str,
additional_info: &'static str,
redirect: Option<RedirectContext>,
}
#[derive(Serialize)]
struct RedirectContext {
url: &'static str,
message: &'static str,
}
#[catch(404)]
pub async fn handle_404() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 404,
error_message: "Not Found",
additional_info: "There's nothing here.",
redirect: Some(RedirectContext {
url: "/",
message: "Home",
}),
},
)
}
#[catch(401)]
pub async fn handle_401() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 401,
error_message: "Unauthorised",
additional_info: "You are not authorised to access this resource.",
redirect: Some(RedirectContext {
url: "/login",
message: "Login",
}),
},
)
}
#[catch(default)]
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
Template::render(
"error",
ErrorContext {
error_code: status.code,
error_message: "Unknown Error",
additional_info: "I don't know what to do with this error.",
redirect: None,
},
)
}
-62
View File
@@ -1,62 +0,0 @@
use rocket::{Request, http::Status};
use rocket_dyn_templates::Template;
use serde::Serialize;
#[derive(Serialize)]
struct ErrorContext {
error_code: u16,
error_message: &'static str,
additional_info: &'static str,
redirect: Option<RedirectContext>,
}
#[derive(Serialize)]
struct RedirectContext {
url: &'static str,
message: &'static str,
}
#[catch(404)]
pub async fn handle_404() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 404,
error_message: "Not Found",
additional_info: "There's nothing here.",
redirect: Some(RedirectContext {
url: "/",
message: "Home",
}),
},
)
}
#[catch(401)]
pub async fn handle_401() -> Template {
Template::render(
"error",
ErrorContext {
error_code: 401,
error_message: "Unauthorised",
additional_info: "You are not authorised to access this resource.",
redirect: Some(RedirectContext {
url: "/login",
message: "Login",
}),
},
)
}
#[catch(default)]
pub async fn handle_default(status: Status, _request: &Request<'_>) -> Template {
Template::render(
"error",
ErrorContext {
error_code: status.code,
error_message: "Unknown Error",
additional_info: "I don't know what to do with this error.",
redirect: None,
},
)
}
+143
View File
@@ -0,0 +1,143 @@
#![deny(clippy::unwrap_used)]
#![warn(clippy::all, clippy::nursery, clippy::cargo, clippy::pedantic)]
#[macro_use]
extern crate rocket;
pub mod api;
pub mod cli;
pub mod error;
pub mod messenger;
pub mod model;
pub mod repo;
pub mod svc;
use crate::repo::channel_repo::ChannelRepository;
use crate::repo::message_repo::MessageRepository;
use crate::repo::space_repo::SpaceRepository;
use crate::repo::user_repo::UserRepository;
use crate::repo::{Repo, access_token_repo::AccessTokenRepo};
use crate::svc::access_token_svc::AccessTokenService;
use crate::svc::auth_svc::AuthService;
use crate::svc::chat_svc::ChatService;
use crate::svc::llm_service::LlmService;
use crate::svc::settings_svc::SettingsService;
use crate::svc::user_svc::UserService;
use api::cdn;
use rocket::http::Method;
use rocket_cors::{AllowedOrigins, CorsOptions};
use sqlx::postgres::PgPoolOptions;
use std::sync::Arc;
use std::time::Duration;
pub fn rocket() -> rocket::Rocket<rocket::Build> {
if let Ok(var) = std::env::var("RELEASE_MODE") && var == "1" {
} else {
dotenv::dotenv().expect("Failed to load .env file");
}
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
println!("Running with database URL: {}", db_url);
let pool = PgPoolOptions::new()
.max_connections(25)
.min_connections(5)
.acquire_timeout(Duration::from_secs(5))
.connect_lazy(&db_url)
.expect("Failed to create database pool");
let user_repo = Arc::new(UserRepository::new(pool.clone()));
let message_repo = MessageRepository::new(pool.clone());
let token_repo = Arc::new(AccessTokenRepo::new(pool.clone()));
let space_repo: Arc<dyn repo::SpaceRepo> = Arc::new(SpaceRepository::new(pool.clone()));
let channel_repo: Arc<dyn repo::ChannelRepo> = Arc::new(ChannelRepository::new(pool.clone()));
let llm_service = LlmService::new();
let chat_service = ChatService::new(
32,
llm_service.clone(),
message_repo.clone(),
user_repo.clone(),
channel_repo.clone(),
space_repo.clone(),
);
rocket_builder(
user_repo,
token_repo,
space_repo,
channel_repo,
chat_service,
)
}
pub fn rocket_builder(
user_repo: Arc<dyn repo::UserRepo>,
token_repo: Arc<dyn repo::AccessTokenRepoTrait>,
space_repo: Arc<dyn repo::SpaceRepo>,
channel_repo: Arc<dyn repo::ChannelRepo>,
chat_service: ChatService,
) -> rocket::Rocket<rocket::Build> {
let cors = CorsOptions::default()
.allowed_origins(AllowedOrigins::all())
.allowed_methods(
vec![Method::Get, Method::Post, Method::Patch]
.into_iter()
.map(From::from)
.collect(),
)
.allow_credentials(true)
.to_cors()
.expect("unable to create cors");
let access_token_svc = AccessTokenService::new(token_repo.clone());
let auth_service = AuthService::new(user_repo.clone(), access_token_svc.clone());
let settings_service = SettingsService::new(auth_service.clone(), user_repo.clone());
let user_service = UserService::new(user_repo.clone());
rocket::build()
.manage(chat_service)
.manage(auth_service)
.manage(settings_service)
.manage(access_token_svc)
.manage(user_service)
.manage(space_repo)
.manage(channel_repo)
.attach(cors)
.mount(
"/api",
routes![
cdn::upload_profile_pic,
api::profile::display_name,
// basic auth
api::auth::login,
api::auth::signup,
api::auth::generate_invite,
// 2fa
api::totp::confirm_totp,
api::totp::disable_totp,
api::totp::get_totp,
api::totp::get_totp_status,
api::totp::verify_totp,
// chat
api::chat::event_stream,
api::chat::post_message,
// user settings
api::settings::change_display_name,
api::settings::change_password,
api::settings::change_username,
api::settings::delete_account,
// spaces
api::space::list_spaces,
api::space::list_channels,
api::space::get_accessible_channels
],
)
// .register(
// "/",
// catchers![error::handle_401, error::handle_404, error::handle_default,],
// )
}
-69
View File
@@ -1,69 +0,0 @@
// src/llm.rs
use serde::{Deserialize, Serialize};
use crate::messenger::ChatMsg;
#[derive(Serialize)]
struct LlmRequest {
model: String,
messages: Vec<Message>,
}
#[derive(Serialize, Deserialize)]
struct Message {
role: String, // "user" or "assistant"
content: String,
}
pub struct LlmWorker {
uri: String,
}
impl LlmWorker {
pub fn new(uri: String) -> Self {
Self { uri }
}
pub async fn query(&self, message: &ChatMsg) -> Result<ChatMsg, String> {
let client = reqwest::Client::new();
// Build the request body
let payload = LlmRequest {
model: "gpt-oss-20b".into(), // whatever model you run locally
messages: vec![Message {
role: "user".into(),
content: message.text.clone(),
}],
};
// POST to lmstudio (default 127.0.0.1:1234)
let resp = client
.post(self.uri.clone())
.json(&payload)
.send()
.await
.map_err(|_| String::from("Failed to make request to LLM server"))?;
// The API returns a JSON with `choices[].message.content`
#[derive(Deserialize)]
struct LlmResponse {
choices: Vec<Choice>,
}
#[derive(Deserialize)]
struct Choice {
message: Message,
}
let llm_resp: LlmResponse = resp
.json()
.await
.map_err(|_| String::from("Failed to make request to LLM server"))?;
Ok(ChatMsg {
display_name: Some(String::from("lmstudio")),
user_id: 0,
text: llm_resp.choices[0].message.content.clone(),
timestamp: chrono::Utc::now().timestamp_millis() as usize,
})
}
}
+9 -96
View File
@@ -1,98 +1,11 @@
// src/main.rs use backend::rocket;
#[macro_use] use backend::cli::handle_cli;
extern crate rocket;
use rocket::fs::{FileServer, NamedFile}; #[rocket::main]
use rocket::http::Method; async fn main() -> Result<(), rocket::Error> {
use rocket::{Build, Rocket}; if handle_cli().await {
use rocket_cors::{AllowedOrigins, CorsOptions}; return Ok(());
use rocket_db_pools::Database; }
use rocket_dyn_templates::Template; rocket().launch().await?;
use std::env; Ok(())
use std::sync::{Arc, LazyLock};
use crate::db::{Postgres, Redis};
pub mod auth;
pub mod cdn;
pub mod db;
pub mod handlers;
pub mod llm;
pub mod messenger;
pub mod user;
static LMSTUDIO_URL: LazyLock<String> =
LazyLock::new(|| env::var("LMSTUDIO_URL").expect("Ensure LMSTUDIO_URL is set!"));
#[launch]
fn rocket() -> Rocket<Build> {
// make sure the env is loaded
dotenv::dotenv().expect("Failed to load env! aborting launch!");
let chat = Arc::new(crate::messenger::ChatBroadcaster::new(32));
let cors = CorsOptions::default()
.allowed_origins(AllowedOrigins::all())
.allowed_methods(
vec![Method::Get, Method::Post, Method::Patch]
.into_iter()
.map(From::from)
.collect(),
)
.allow_credentials(true);
rocket::build()
.manage(chat)
.attach(cors.to_cors().unwrap())
.attach(Postgres::init())
.attach(Redis::init())
.attach(Template::fairing())
.mount("/static", FileServer::from("static"))
.mount("/cdn", cdn::routes())
.mount(
"/",
routes![
favicon,
messenger::chat_page,
auth::signup_page,
auth::login_page,
auth::mfa_page,
auth::invite_page,
],
)
.mount(
"/api",
routes![
cdn::upload_profile_pic,
messenger::post_message,
messenger::event_stream,
user::users,
user::display_name,
auth::signup,
auth::login,
auth::get_totp,
auth::confirm_totp,
auth::generate_invite,
auth::verify_totp,
auth::disable_totp,
auth::get_totp_status,
auth::change_password,
auth::change_display_name,
auth::change_username,
auth::delete_account,
],
)
.register(
"/",
catchers![
handlers::handle_404,
handlers::handle_401,
handlers::handle_default
],
)
}
#[get("/favicon.ico")]
async fn favicon() -> NamedFile {
NamedFile::open("static/favicon.ico").await.unwrap()
} }
+2 -4
View File
@@ -1,10 +1,8 @@
use redis::AsyncCommands; use redis::AsyncCommands;
use rocket_db_pools::Connection; use rocket_db_pools::Connection;
use crate::{ use crate::api::chat::ChatMsg;
db::{Postgres, Redis}, use crate::db::{Postgres, Redis};
messenger::ChatMsg,
};
// Helper function to cache message in Redis // Helper function to cache message in Redis
pub async fn insert( pub async fn insert(
-220
View File
@@ -1,220 +0,0 @@
use std::sync::Arc;
use rocket::{
Shutdown,
response::stream::{Event, EventStream},
serde::json::Json,
time::OffsetDateTime,
};
use rocket_db_pools::Connection;
use rocket_dyn_templates::{Template, context};
use serde::{Deserialize, Serialize};
use sqlx::prelude::FromRow;
use tokio::{select, sync::broadcast};
use crate::{
auth::Session,
db::{Postgres, Redis},
llm::LlmWorker,
messenger,
};
/// ---------- shared broadcaster ----------
pub struct ChatBroadcaster {
buffer_size: usize,
senders: std::sync::Mutex<std::collections::HashMap<i32, broadcast::Sender<ChatMsg>>>,
}
impl ChatBroadcaster {
pub fn new(buffer_size: usize) -> Self {
Self {
buffer_size,
senders: std::sync::Mutex::new(std::collections::HashMap::new()),
}
}
/// Publish a message to the specified channel.
pub async fn publish(&self, channel_id: i32, msg: ChatMsg) {
let mut map = self.senders.lock().unwrap();
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
let _ = sender.send(msg);
}
/// Subscribe to the specified channel.
pub fn subscribe(&self, channel_id: i32) -> broadcast::Receiver<ChatMsg> {
let mut map = self.senders.lock().unwrap();
let sender = map
.entry(channel_id)
.or_insert_with(|| broadcast::channel::<ChatMsg>(self.buffer_size).0);
sender.subscribe()
}
}
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub display_name: Option<String>,
pub user_id: usize,
pub text: String,
pub timestamp: usize,
}
#[post("/chat/<channel_id>", format = "json", data = "<msg>")]
pub async fn post_message(
mut msg: Json<ChatMsg>,
chat: &rocket::State<Arc<ChatBroadcaster>>,
mut postgres: Connection<Postgres>,
mut cache: Option<Connection<Redis>>,
session: Session,
channel_id: i32,
) -> Result<(), String> {
let chat = chat.inner().clone();
let display_name = sqlx::query!(
"SELECT display_name, username FROM users WHERE id = $1",
session.user_id as i32
)
.fetch_one(&mut **postgres)
.await
.map(|row| row.display_name.unwrap_or(row.username))
.unwrap_or_else(|_| "Unknown".to_string());
msg.user_id = session.user_id;
msg.display_name = Some(display_name);
chat.publish(channel_id, msg.clone().into_inner()).await;
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
channel_id,
msg.user_id as i32,
msg.text,
OffsetDateTime::from_unix_timestamp_nanos(msg.timestamp as i128 * 1_000_000).unwrap()
)
.execute(&mut **postgres)
.await
.map_err(|_| "Failed".to_string())?;
println!("gisfujdeghnjuisdfjngiosdfgjkosdf gnojdfsg nmodfsg");
if let Some(ref mut cache) = cache {
messenger::cache::insert(cache, channel_id, &msg)
.await
.map_err(|_| "Redis cache failed".to_string())?;
}
// get response
tokio::spawn(async move {
let response = LlmWorker::new(crate::LMSTUDIO_URL.to_string())
.query(&msg)
.await;
if let Ok(reply) = response {
chat.publish(channel_id, reply.clone()).await;
if let Some(ref mut cache) = cache {
messenger::cache::insert(cache, channel_id, &reply)
.await
.ok();
}
sqlx::query!(
"INSERT INTO messages (channel_id, user_id, content, created_at) VALUES ($1, $2, $3, $4)",
channel_id,
reply.user_id as i32,
reply.text,
OffsetDateTime::from_unix_timestamp_nanos(reply.timestamp as i128 * 1_000_000).unwrap()
)
.execute(&mut **postgres)
.await
.map_err(|_| "Failed".to_string())
.unwrap();
}
});
Ok(())
}
pub async fn get_messages(
mut db: Connection<Postgres>,
mut redis: Connection<Redis>,
channel_id: i32,
) -> Json<Vec<ChatMsg>> {
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
&& !messages.is_empty()
{
return Json(messages);
};
if let Err(x) = messenger::cache::initialise(&mut redis, &mut db, channel_id).await {
eprintln!("WARN: {x:?}");
}
if let Ok(messages) = messenger::cache::get(&mut redis, channel_id).await
&& !messages.is_empty()
{
return Json(messages);
};
let res = sqlx::query!(
"SELECT u.username, u.display_name, u.id, m.content, m.created_at
FROM messages m
JOIN users u ON m.user_id = u.id
WHERE m.channel_id = $1
ORDER BY m.created_at DESC LIMIT 100",
channel_id
)
.fetch_all(&mut **db)
.await
.unwrap_or_else(|_| Vec::new())
.into_iter()
.rev()
.map(|msg| ChatMsg {
display_name: Some(msg.display_name.unwrap_or(msg.username)),
user_id: msg.id as usize,
text: msg.content,
timestamp: (msg.created_at.unwrap().unix_timestamp_nanos() / 1_000_000) as usize,
})
.collect();
Json(res)
}
#[get("/events/<channel_id>")]
pub async fn event_stream(
chat: &rocket::State<Arc<ChatBroadcaster>>,
postgres: Connection<Postgres>,
cache: Connection<Redis>,
_session: Session,
mut shutdown: Shutdown,
channel_id: i32,
) -> EventStream![] {
let mut rx = chat.subscribe(channel_id);
EventStream! {
// Initialize the stream with the last 100 messages
for msg in get_messages(postgres, cache, channel_id).await.0 {
yield Event::json(&msg);
}
loop {
select!{
// exit early on shutdown
_ = &mut shutdown => break,
msg = rx.recv() => match msg {
Ok(msg) => yield Event::json(&msg),
Err(broadcast::error::RecvError::Lagged(_)) => {
yield Event::comment("RecvError::Lagged");
}
Err(broadcast::error::RecvError::Closed) => break,
},
}
}
}
}
#[get("/chat")]
pub async fn chat_page(session: Session) -> Template {
Template::render("chat", context!(user_id: session.user_id))
}
+1 -4
View File
@@ -1,4 +1 @@
mod cache; // mod cache;
mod messages;
pub use messages::{ChatBroadcaster, ChatMsg, chat_page, event_stream, get_messages, post_message};
+35
View File
@@ -0,0 +1,35 @@
use rocket::serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
#[derive(Serialize, Deserialize)]
pub struct SignupCredentials {
pub email: String,
pub username: String,
pub password: String,
pub access_token: String,
}
#[derive(Serialize, Deserialize)]
pub struct LoginCredentials {
pub username: String,
pub password: String,
}
#[derive(Serialize, Deserialize)]
pub struct AuthResponse {
pub token: String,
pub totp_required: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenForm {
pub name: String,
pub max_uses: i32,
pub expiry_date: DateTime<Utc>,
pub start_date: DateTime<Utc>,
}
pub struct AccessToken {
pub id: i64,
pub code: String,
}
+27
View File
@@ -0,0 +1,27 @@
use rocket::serde::{Deserialize, Serialize};
use sqlx::FromRow;
use chrono::{DateTime, Utc};
use uuid::Uuid;
/// ---------- Rocket routes ----------
#[derive(Debug, Serialize, Deserialize, Clone, FromRow)]
pub struct ChatMsg {
pub id: Uuid,
pub display_name: Option<String>,
pub user_id: i64,
pub text: String,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ChatEvent {
SendMessage(ChatMsg),
/// for when a user explicitly edits a message
EditMessage { id: Uuid, msg: ChatMsg },
/// used for streaming content to a message
/// will not show up as edited
MessageAppendContent{ id: Uuid, content: String }
}
+4
View File
@@ -0,0 +1,4 @@
pub mod auth;
pub mod user;
pub mod space;
pub mod event;
+35
View File
@@ -0,0 +1,35 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use chrono::{DateTime, Utc};
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Space {
pub id: i64,
pub name: String,
pub description: Option<String>,
pub owner_id: i64,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Channel {
pub id: i64,
pub name: String,
pub description: Option<String>,
pub space_id: i64,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpaceDto {
pub channels: Vec<Channel>,
pub id: i64,
pub owner_id: i64,
pub name: String,
pub description: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
+61
View File
@@ -0,0 +1,61 @@
use crate::api::auth::Session;
use crate::error::ApiResult;
use crate::svc::user_svc::UserService;
use chrono::{DateTime, Utc};
use rocket::serde::{Deserialize, Serialize};
use rocket::State;
use sqlx::FromRow;
use crate::api::totp::TotpStatus;
#[derive(Clone)]
#[derive(FromRow)]
pub struct User {
pub id: i64,
pub email: Option<String>,
pub username: String,
pub nickname: Option<String>,
pub passhash: String,
pub totp_status: TotpStatus,
pub totp_secret: Option<String>,
pub created_at: Option<DateTime<Utc>>,
pub updated_at: Option<DateTime<Utc>>,
}
#[derive(Debug, sqlx::Type, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
pub enum UserRole {
User,
Admin,
}
// pub struct UserCache {}
//
// impl UserCache {
// pub async fn username(
// id: usize,
// redis_conn: &mut Connection<Redis>,
// pgsql_conn: &mut Connection<Postgres>,
// ) -> String {
// if let Ok(val) = redis_conn.get(format!("users:{id}")).await {
// return val;
// }
//
// if let Ok(v) = sqlx::query!("SELECT username FROM users WHERE id = $1", id as i32)
// .fetch_one(&mut ***pgsql_conn)
// .await
// {
// let username = v.username;
// Self::insert(id, &username, redis_conn).await;
// username
// } else {
// unimplemented!()
// }
// }
//
// pub async fn insert(id: usize, username: &str, conn: &mut Connection<Redis>) {
// conn.set_ex::<_, _, ()>(format!("users:{id}"), username.to_string(), 1800)
// .await
// .expect("failed to insert key");
// }
// }
+67
View File
@@ -0,0 +1,67 @@
use crate::repo::{Repo, AccessTokenRepoTrait};
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use crate::model::auth::AccessToken;
#[derive(Clone)]
pub struct AccessTokenRepo {
pool: PgPool
}
impl Repo for AccessTokenRepo {
type Target = AccessToken;
fn new(pool: PgPool) -> Self {
Self { pool }
}
async fn get_by_id(&self, id: i64) -> Option<Self::Target> {
sqlx::query_as!(AccessToken, "SELECT id, code FROM access_tokens WHERE id = $1", id)
.fetch_optional(&self.pool)
.await
.unwrap_or(None)
}
}
#[async_trait]
impl AccessTokenRepoTrait for AccessTokenRepo {
async fn get_by_id(&self, id: i64) -> Option<AccessToken> {
Repo::get_by_id(self, id).await
}
async fn create_new(&self,
uid: i64, name: &str, code: &str, max_uses: i32,
start_date: DateTime<Utc>, expiry_date: DateTime<Utc>
) -> Result<i64, sqlx::Error> {
sqlx::query!(
"INSERT INTO access_tokens (name, code, creator_id, max_uses, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id",
name,
code,
uid,
max_uses,
start_date,
expiry_date
).fetch_optional(&self.pool).await.and_then(|row| row.map(|r| r.id).ok_or(sqlx::Error::RowNotFound))
}
async fn use_token(&self, id: i64) -> Result<(), sqlx::Error> {
sqlx::query!("UPDATE access_tokens SET uses = uses + 1 WHERE id = $1", id)
.execute(&self.pool)
.await?;
Ok(())
}
async fn get_code_not_expired(&self, code: &str) -> Result<Option<AccessToken>, sqlx::Error> {
sqlx::query_as!(AccessToken,
"SELECT id, code FROM access_tokens
WHERE code = $1
AND created_at < NOW()
AND expires_at > NOW()
AND uses < max_uses",
code
)
.fetch_optional(&self.pool)
.await
}
}
+39
View File
@@ -0,0 +1,39 @@
use crate::repo::ChannelRepo;
use crate::model::space::Channel;
use sqlx::PgPool;
#[derive(Clone)]
pub struct ChannelRepository {
pool: PgPool,
}
impl ChannelRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
#[rocket::async_trait]
impl ChannelRepo for ChannelRepository {
async fn create(&self, name: &str, description: Option<&str>, space_id: i64) -> Result<i64, sqlx::Error> {
let row = sqlx::query!(
"INSERT INTO channels (name, description, space_id) VALUES ($1, $2, $3) RETURNING id",
name,
description,
space_id
)
.fetch_one(&self.pool)
.await?;
Ok(row.id)
}
async fn get_by_space_id(&self, space_id: i64) -> Result<Vec<Channel>, sqlx::Error> {
sqlx::query_as!(
Channel,
"SELECT id, name, description, space_id, created_at as \"created_at!\", updated_at as \"updated_at!\" FROM channels WHERE space_id = $1",
space_id
)
.fetch_all(&self.pool)
.await
}
}

Some files were not shown because too many files have changed in this diff Show More