"""Test repository pattern implementations."""

import uuid

import pytest

from app.data.database import init_db, get_session
from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
from app.data.models import Session, Note, Link


async def _create_session_id(**fields):
    """Create a session in its own DB session and return its ``id``.

    ``fields`` are forwarded unchanged to ``SessionRepository.create``.
    """
    async with get_session() as db:
        session = await SessionRepository(db).create(**fields)
        # Read the id while the DB session is still open.
        return session.id


def _make_two_notes(session_id):
    """Build (but do not persist) two tagged notes attached to ``session_id``."""
    return [
        Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session_id),
        Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session_id),
    ]


class TestSessionRepository:
    """Test SessionRepository operations."""

    @pytest.mark.asyncio
    async def test_create_session(self):
        """Test creating a new session."""
        await init_db()

        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="active")

            assert session.id is not None
            assert session.status == "active"
            assert session.transcript == []

    @pytest.mark.asyncio
    async def test_get_session(self):
        """Test retrieving a session by ID."""
        await init_db()

        session_id = await _create_session_id(status="test")

        # Retrieve the session in a fresh DB session.
        async with get_session() as db:
            retrieved = await SessionRepository(db).get(session_id)

            assert retrieved is not None
            assert retrieved.id == session_id
            assert retrieved.status == "test"

    @pytest.mark.asyncio
    async def test_update_status(self):
        """Test updating session status and that the change is persisted."""
        await init_db()

        session_id = await _create_session_id(status="active")

        async with get_session() as db:
            updated = await SessionRepository(db).update_status(session_id, "completed")

            assert updated is not None
            assert updated.status == "completed"

        # Re-read in a new DB session: asserting only on the returned object
        # would not detect an update that was never written back.
        async with get_session() as db:
            reloaded = await SessionRepository(db).get(session_id)

            assert reloaded is not None
            assert reloaded.status == "completed"

    @pytest.mark.asyncio
    async def test_append_transcript(self):
        """Test appending to session transcript and that the change is persisted."""
        await init_db()

        session_id = await _create_session_id()
        message = {"role": "user", "content": "Hello"}

        async with get_session() as db:
            updated = await SessionRepository(db).append_transcript(session_id, message)

            assert updated is not None
            assert len(updated.transcript) == 1
            assert updated.transcript[0] == message

        # Re-read in a new DB session so an in-memory-only append is caught.
        async with get_session() as db:
            reloaded = await SessionRepository(db).get(session_id)

            assert reloaded is not None
            assert reloaded.transcript == [message]


class TestNoteRepository:
    """Test NoteRepository operations."""

    @pytest.mark.asyncio
    async def test_create_bulk_notes(self):
        """Test creating multiple notes at once."""
        await init_db()

        session_id = await _create_session_id()
        notes = _make_two_notes(session_id)

        async with get_session() as db:
            created_notes = await NoteRepository(db).create_bulk(notes)

            assert len(created_notes) == 2
            assert all(note.id is not None for note in created_notes)
            assert created_notes[0].title == "Note 1"
            assert created_notes[1].title == "Note 2"

    @pytest.mark.asyncio
    async def test_get_by_session(self):
        """Test retrieving notes by session ID."""
        await init_db()

        session_id = await _create_session_id()
        async with get_session() as db:
            await NoteRepository(db).create_bulk(_make_two_notes(session_id))

        # Retrieve the notes in a fresh DB session.
        async with get_session() as db:
            session_notes = await NoteRepository(db).get_by_session(session_id)

            assert len(session_notes) == 2
            titles = {note.title for note in session_notes}
            assert "Note 1" in titles
            assert "Note 2" in titles


class TestLinkRepository:
    """Test LinkRepository operations."""

    @pytest.mark.asyncio
    async def test_create_bulk_links(self):
        """Test creating multiple links at once."""
        await init_db()

        # A link needs two persisted notes to point at.
        session_id = await _create_session_id()
        async with get_session() as db:
            created_notes = await NoteRepository(db).create_bulk(_make_two_notes(session_id))
            note_ids = [note.id for note in created_notes]

        source_id, target_id = note_ids
        links = [
            Link(
                context="These notes are related because...",
                source_id=source_id,
                target_id=target_id,
            )
        ]

        async with get_session() as db:
            created_links = await LinkRepository(db).create_bulk(links)

            assert len(created_links) == 1
            assert created_links[0].id is not None
            assert created_links[0].source_id == source_id
            assert created_links[0].target_id == target_id