228 lines
7.2 KiB
Python
228 lines
7.2 KiB
Python
import argparse
|
|
import json
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass
|
|
from getpass import getpass
|
|
from typing import List
|
|
|
|
import requests
|
|
|
|
|
|
# Root URL of the chat server under test. Every endpoint this script calls
# (signup, login, chat posting, event stream) is built by appending an
# "/api/..." path to this value.
BASE_URL = "http://localhost:8000"
|
|
|
|
|
|
@dataclass
class AuthResult:
    """Credentials obtained from a successful login.

    Attributes:
        token: Value of the ``"token"`` field in the login response; the
            rest of the script sends it as ``Authorization: Bearer <token>``.
    """

    token: str
|
|
|
|
|
|
def signup(session: requests.Session, email: str, username: str, password: str, access_token: str) -> None:
    """Register an account with ``POST {BASE_URL}/api/signup``.

    Args:
        session: HTTP session used to send the request.
        email: Email address for the new account.
        username: Desired username.
        password: Desired password.
        access_token: Sent as the ``access_token`` field of the JSON body
            (presumably an invite/registration code — confirm with the server).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        Other ``requests`` exceptions (connection errors, the 10 s timeout)
        propagate unchanged.
    """
    body = dict(
        email=email,
        username=username,
        password=password,
        access_token=access_token,
    )
    response = session.post(f"{BASE_URL}/api/signup", json=body, timeout=10)
    response.raise_for_status()
|
|
|
|
|
|
def login(session: requests.Session, username: str, password: str) -> AuthResult:
    """Log in with ``POST {BASE_URL}/api/login`` and return the issued token.

    Args:
        session: HTTP session used to send the request.
        username: Account username.
        password: Account password.

    Returns:
        An ``AuthResult`` wrapping the ``"token"`` field of the JSON response.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        RuntimeError: If the response JSON has no truthy ``"token"`` value.
    """
    response = session.post(
        f"{BASE_URL}/api/login",
        json={"username": username, "password": password},
        timeout=10,
    )
    response.raise_for_status()

    body = response.json()
    token = body.get("token")
    if token:
        return AuthResult(token=token)
    raise RuntimeError(f"Login response did not contain token: {body}")
|
|
|
|
|
|
def post_message(session: requests.Session, channel_id: int, token: str, text: str, display_name: str, user_id: int) -> None:
    """Post one chat message to ``POST {BASE_URL}/api/chat/{channel_id}``.

    The JSON body carries the display name, user id, text and a UTC timestamp
    formatted as ``YYYY-MM-DDTHH:MM:SSZ``; the request is authorised with
    ``Authorization: Bearer <token>``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    sent_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    message = {
        "display_name": display_name,
        "user_id": user_id,
        "text": text,
        "timestamp": sent_at,
    }
    auth_header = {"Authorization": f"Bearer {token}"}
    reply = session.post(
        f"{BASE_URL}/api/chat/{channel_id}",
        json=message,
        headers=auth_header,
        timeout=10,
    )
    reply.raise_for_status()
|
|
|
|
|
|
def read_sse_messages(
    session: requests.Session,
    channel_id: int,
    token: str,
    expected_count: int,
    timeout_s: int,
    capture_live_messages: threading.Event,
) -> List[dict]:
    """Read JSON events from ``GET {BASE_URL}/api/events/{channel_id}`` (SSE).

    The response body is parsed as a Server-Sent Events stream: ``data:``
    lines are accumulated, and at the blank line that ends an event they are
    joined with newlines and decoded as JSON. Payloads that are not valid
    JSON are skipped; other SSE lines (``event:``, ``id:``, ``:`` comments)
    are ignored.

    A decoded event is kept only if *capture_live_messages* is set at the
    moment it is dispatched; otherwise it is printed and discarded.

    Args:
        session: HTTP session used to open the stream.
        channel_id: Channel whose event stream is read.
        token: Bearer token for the ``Authorization`` header.
        expected_count: Stop after this many events have been kept;
            ``<= 0`` means keep reading until the deadline.
        timeout_s: Wall-clock budget in seconds (checked whenever a line
            arrives); also used as the socket read timeout.
        capture_live_messages: Event gating whether decoded events are kept.

    Returns:
        Kept events in arrival order. Note that ``json.loads`` may yield any
        JSON value, not only dicts. Network errors are reported on stdout and
        end the read early instead of raising.
    """
    url = f"{BASE_URL}/api/events/{channel_id}"
    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "text/event-stream",
    }

    received: List[dict] = []
    deadline = time.monotonic() + timeout_s

    try:
        with session.get(url, headers=headers, stream=True, timeout=(5, timeout_s)) as resp:
            resp.raise_for_status()

            event_data_lines: List[str] = []

            # Iterate raw bytes and decode each line ourselves. With
            # decode_unicode=True, requests falls back to ISO-8859-1 for a
            # text/* response without a charset and then splits with
            # str.splitlines, which also breaks on "\x85", U+2028 and U+2029.
            # That garbles non-ASCII text and can cut a JSON payload in half.
            # bytes.splitlines only breaks on CR/LF/CRLF, exactly the SSE
            # line terminators, and SSE streams are always UTF-8.
            for raw_line in resp.iter_lines():
                if time.monotonic() > deadline:
                    break

                if raw_line is None:
                    continue

                line = raw_line.decode("utf-8", errors="replace").strip()

                if line.startswith("data:"):
                    event_data_lines.append(line[len("data:"):].strip())
                    continue

                # Non-blank, non-data lines are ignored; a blank line with no
                # pending data has nothing to dispatch.
                if line or not event_data_lines:
                    continue

                # Blank line after data: the event is complete.
                joined = "\n".join(event_data_lines)
                event_data_lines.clear()
                try:
                    obj = json.loads(joined)
                except json.JSONDecodeError:
                    continue

                if not capture_live_messages.is_set():
                    print(f"Discarding message: {obj}")
                    continue

                received.append(obj)
                if expected_count > 0 and len(received) >= expected_count:
                    break

    except requests.exceptions.Timeout:
        print("Timeout while reading SSE.")
    except requests.exceptions.RequestException as exc:
        # NOTE(review): requests appears to surface a read timeout that
        # happens *during* streaming iteration as ConnectionError, so such
        # timeouts likely land here rather than in the Timeout branch above.
        print(f"Error reading SSE: {exc}")

    return received
|
|
|
|
|
|
def prompt_nonempty(label: str, secret: bool = False) -> str:
    """Prompt until the user enters something other than whitespace.

    Args:
        label: Prompt text shown to the user.
        secret: If true, read with ``getpass`` (no echo) instead of ``input``.

    Returns:
        For plain prompts, the entered text with surrounding whitespace
        stripped. For secret prompts, the text exactly as typed: surrounding
        spaces can be significant in a password, and a password supplied via
        the command line is used unstripped too, so stripping here would make
        the two input paths disagree.
    """
    while True:
        if secret:
            value = getpass(label)
            # Reject all-blank input without altering the secret itself.
            if value.strip():
                return value
        else:
            value = input(label).strip()
            if value:
                return value
        print("Please enter a value.")
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this integration test."""
    parser = argparse.ArgumentParser(description="Chat integration test against localhost:8000")
    parser.add_argument(
        "--existing-account",
        action="store_true",
        help="Skip signup and only log in with an existing account",
    )
    parser.add_argument("--email", default=None)
    parser.add_argument("--username", default=None)
    parser.add_argument("--password", default=None)
    parser.add_argument("--access-token", default=None, help="Required only for signup mode")
    parser.add_argument("--channel-id", type=int, default=1)
    parser.add_argument("--message-count", type=int, default=5)
    parser.add_argument("--timeout", type=int, default=15)
    return parser


def main() -> int:
    """Run the end-to-end chat test and return a process exit code.

    Steps: optionally sign up, log in, read the channel's event stream on a
    background thread, post ``--message-count`` messages, then verify that
    every posted text came back over the stream.

    Returns:
        0 if every sent message was received, 1 otherwise (including login
        or send failures).
    """
    # Local import: only used to tag this run's messages uniquely.
    import uuid

    args = _build_arg_parser().parse_args()
    session = requests.Session()

    if args.existing_account:
        username = args.username or prompt_nonempty("Username: ")
        password = args.password or prompt_nonempty("Password: ", secret=True)
    else:
        email = args.email or prompt_nonempty("Email: ")
        username = args.username or prompt_nonempty("Username: ")
        password = args.password or prompt_nonempty("Password: ", secret=True)
        access_token = args.access_token or prompt_nonempty("Access token: ")

        print("[1/5] Signing up...")
        try:
            signup(session, email, username, password, access_token)
            print(" signup ok")
        except requests.HTTPError as e:
            # An HTTP error here may just mean the account already exists, so
            # carry on and let the login step decide.
            print(f" signup returned HTTP error: {e}")
            print(" continuing to login...")

    print("[2/5] Logging in...")
    try:
        auth = login(session, username, password)
    except (requests.RequestException, RuntimeError) as exc:
        print(f"\nFAIL: login failed: {exc}")
        return 1
    print(" login ok")
    print(f" token: {auth.token[:12]}...")

    print("[3/5] Opening event stream...")
    received_messages: List[dict] = []
    capture_live_messages = threading.Event()
    stream_done = threading.Event()

    def stream_reader() -> None:
        # Runs on the background thread; publishes the result via nonlocal.
        nonlocal received_messages
        try:
            received_messages = read_sse_messages(
                session=session,
                channel_id=args.channel_id,
                token=auth.token,
                expected_count=args.message_count,
                timeout_s=args.timeout,
                capture_live_messages=capture_live_messages,
            )
        finally:
            stream_done.set()

    reader = threading.Thread(target=stream_reader, daemon=True)
    reader.start()

    # Give the server time to flush backlog on this same stream connection.
    time.sleep(1.0)

    print("[4/5] Starting to capture live messages and sending messages...")
    capture_live_messages.set()

    # Tag every text with a per-run id. Without it, backlog left over from
    # earlier runs (which posted the same "Message i" texts) could be
    # captured after the sleep above and make the check below pass falsely.
    run_tag = uuid.uuid4().hex[:8]
    sent_texts = [f"Message {i} [{run_tag}]" for i in range(args.message_count)]
    for i, text in enumerate(sent_texts, start=1):
        try:
            post_message(
                session=session,
                channel_id=args.channel_id,
                token=auth.token,
                text=text,
                display_name=username,
                user_id=1,
            )
        except requests.RequestException as exc:
            print(f"\nFAIL: could not send {text!r}: {exc}")
            return 1
        print(f" sent {i}/{args.message_count}: {text}")
        time.sleep(0.1)

    print("[5/5] Waiting for messages on the event stream...")
    stream_done.wait(timeout=args.timeout)
    reader.join(timeout=1)

    print("\nReceived messages:")
    for i, msg in enumerate(received_messages, start=1):
        print(f" {i}. {msg}")

    received_texts = {m.get("text") for m in received_messages if isinstance(m, dict)}
    missing = [text for text in sent_texts if text not in received_texts]
    if missing:
        for text in missing:
            print(f"\nFAIL: missing message: {text}")
        return 1

    print("\nPASS: login and message delivery test succeeded.")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns 0/1; hand it to the interpreter as the exit status.
    exit_status = main()
    raise SystemExit(exit_status)