- Define SQLModel schemas for Session, Note, and Link entities - Add API request/response models for RPC endpoints - Create LLM structured output models for Zettel extraction - Set up async database initialization with SQLModel and aiosqlite - Implement repository pattern for CRUD operations - Add complete test suite with pytest configuration - Create validation test runner for development workflow - Add .gitignore for Python/FastAPI project security
73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
"""Test database initialization and basic operations."""
|
|
import os
import tempfile
import uuid

import pytest
from sqlmodel import SQLModel, func, select

from app.data.database import engine, get_session, init_db
from app.data.models import Link, Note, Session
|
|
|
|
|
|
class TestDatabase:
    """Integration tests for database initialization and session handling.

    Every test calls ``init_db()`` first and then works against whatever
    database ``app.data.database`` is configured for. No isolated or
    temporary database is set up here, so rows committed by one test may
    still be visible to the others. Assertions therefore key on values
    that are unique to the current test run rather than on shared literals.
    """

    @pytest.mark.asyncio
    async def test_init_db(self):
        """init_db() creates the schema, so a Session row can be persisted."""
        await init_db()

        # The INSERT only succeeds if the table exists. A database-assigned
        # primary key after refresh proves the row round-tripped.
        async with get_session() as db:
            record = Session(status="test")
            db.add(record)
            await db.commit()
            await db.refresh(record)
            assert record.id is not None

    @pytest.mark.asyncio
    async def test_session_context_manager(self):
        """Changes made inside get_session() are committed when the block exits."""
        await init_db()

        # Use a status unique to this run. A fixed value such as "test" would
        # also match rows other tests committed (test_init_db inserts
        # status="test"), and the check would pass even without auto-commit.
        status = f"test-{uuid.uuid4().hex}"

        # Deliberately no explicit commit: get_session() is expected to
        # commit on a clean exit.
        async with get_session() as db:
            db.add(Session(status=status))

        # Read back through a fresh session so only persisted data is seen.
        async with get_session() as db:
            result = await db.exec(select(Session).where(Session.status == status))
            found = result.first()
            assert found is not None
            assert found.status == status

    @pytest.mark.asyncio
    async def test_session_rollback_on_error(self):
        """An error inside get_session() propagates and nothing is committed."""
        await init_db()

        async with get_session() as db:
            result = await db.exec(select(func.count(Session.id)))
            initial_count = result.one()

        # pytest.raises also fails the test if the context manager swallows
        # the exception. A bare try/except/pass would hide that.
        with pytest.raises(ValueError, match="Simulated error"):
            async with get_session() as db:
                db.add(Session(status="test_rollback"))
                raise ValueError("Simulated error")

        # The pending INSERT must have been discarded, not committed.
        async with get_session() as db:
            result = await db.exec(select(func.count(Session.id)))
            final_count = result.one()
            assert final_count == initial_count  # No new records