"""Test database initialization and basic operations."""

import os
import tempfile
import uuid

import pytest
from sqlmodel import SQLModel, func, select

from app.data.database import init_db, get_session, engine
from app.data.models import Session, Note, Link


def _unique_status(prefix: str) -> str:
    """Return a ``Session.status`` value that no other test will write.

    Nothing in this module resets the database between tests, so rows
    committed by one test may still be visible to the next. Tagging each
    test's rows with a unique marker lets assertions target exactly the rows
    that test wrote, instead of matching leftovers from earlier tests.
    """
    return f"{prefix}_{uuid.uuid4().hex}"


class TestDatabase:
    """Test database operations."""

    @pytest.mark.asyncio
    async def test_init_db(self):
        """Test database initialization creates tables."""
        await init_db()

        # Check that tables exist by inserting a row and reading back its
        # generated primary key.
        async with get_session() as db:
            session = Session(status="test")
            db.add(session)
            await db.commit()
            await db.refresh(session)
            assert session.id is not None

    @pytest.mark.asyncio
    async def test_session_context_manager(self):
        """Test that ``get_session`` commits pending changes on a clean exit."""
        await init_db()
        # A unique status ensures the lookup below can only match the row
        # added here, not a row committed by another test.
        status = _unique_status("test_ctx")

        # Successful transaction: no explicit commit; the context manager
        # is expected to auto-commit on exit.
        async with get_session() as db:
            db.add(Session(status=status))

        # Verify, from a fresh session, that the data was saved.
        async with get_session() as db:
            result = await db.exec(select(Session).where(Session.status == status))
            rows = result.all()

        assert len(rows) == 1
        assert rows[0].status == status

    @pytest.mark.asyncio
    async def test_session_rollback_on_error(self):
        """Test that ``get_session`` rolls back and re-raises on errors."""
        await init_db()
        status = _unique_status("test_rollback")

        # Simulate an error during the transaction. pytest.raises also
        # asserts the error propagates out of the context manager instead
        # of being silently swallowed.
        with pytest.raises(ValueError, match="Simulated error"):
            async with get_session() as db:
                db.add(Session(status=status))
                raise ValueError("Simulated error")

        # Verify the row was not saved due to rollback. Counting only rows
        # with this test's marker keeps the check independent of rows
        # written by other tests.
        async with get_session() as db:
            result = await db.exec(
                select(func.count(Session.id)).where(Session.status == status)
            )
            saved_count = result.one()

        assert saved_count == 0  # No new records