Files
app/tests/test_database.py
Albert f60d61a78f feat: implement data layer with comprehensive test infrastructure
- Define SQLModel schemas for Session, Note, and Link entities
  - Add API request/response models for RPC endpoints
  - Create LLM structured output models for Zettel extraction
  - Set up async database initialization with SQLModel and aiosqlite
  - Implement repository pattern for CRUD operations
  - Add complete test suite with pytest configuration
  - Create validation test runner for development workflow
  - Add .gitignore for Python/FastAPI project security
2025-08-17 01:25:16 +00:00

73 lines
2.5 KiB
Python

"""Test database initialization and basic operations."""
import os
import tempfile
import uuid

import pytest
from sqlmodel import SQLModel, select

from app.data.database import engine, get_session, init_db
from app.data.models import Link, Note, Session
class TestDatabase:
    """Test database initialization and the ``get_session`` context manager.

    Each test tags the rows it writes with a unique ``status`` value and only
    asserts against that value. ``init_db`` is called repeatedly against the
    module-level ``engine`` and no per-test reset is visible in this file, so a
    fixed marker such as ``"test"`` could let one test pass on rows committed
    by another.
    """

    @staticmethod
    def _unique_status(prefix: str) -> str:
        """Return a ``status`` marker that is unique to this call."""
        return f"{prefix}-{uuid.uuid4().hex[:12]}"

    @pytest.mark.asyncio
    async def test_init_db(self):
        """Test database initialization creates tables."""
        await init_db()

        # Inserting and refreshing a row only succeeds if the table exists.
        async with get_session() as db:
            session = Session(status=self._unique_status("test_init"))
            db.add(session)
            await db.commit()
            await db.refresh(session)
            assert session.id is not None

    @pytest.mark.asyncio
    async def test_session_context_manager(self):
        """Test that ``get_session`` persists pending changes on clean exit."""
        await init_db()
        status = self._unique_status("test_ctx")

        # No explicit commit here: leaving the block is what should persist it.
        async with get_session() as db:
            db.add(Session(status=status))

        # A fresh session can only find this marker if it was really committed.
        async with get_session() as db:
            statement = select(Session).where(Session.status == status)
            result = await db.exec(statement)
            found_session = result.first()
            # Assert while the session is open so attributes are still loaded.
            assert found_session is not None
            assert found_session.status == status

    @pytest.mark.asyncio
    async def test_session_rollback_on_error(self):
        """Test that ``get_session`` rolls back and re-raises on errors."""
        await init_db()
        status = self._unique_status("test_rollback")

        # The error must propagate out of the context manager, not be swallowed.
        with pytest.raises(ValueError, match="Simulated error"):
            async with get_session() as db:
                db.add(Session(status=status))
                raise ValueError("Simulated error")

        # Nothing tagged with this test's marker may have been persisted.
        async with get_session() as db:
            statement = select(Session).where(Session.status == status)
            result = await db.exec(statement)
            assert result.first() is None