feat: implement data layer with comprehensive test infrastructure
- Define SQLModel schemas for Session, Note, and Link entities - Add API request/response models for RPC endpoints - Create LLM structured output models for Zettel extraction - Set up async database initialization with SQLModel and aiosqlite - Implement repository pattern for CRUD operations - Add complete test suite with pytest configuration - Create validation test runner for development workflow - Add .gitignore for Python/FastAPI project security
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
24
tests/test_config.py
Normal file
24
tests/test_config.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Test configuration and environment setup."""
|
||||
import pytest
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def test_config_loads():
|
||||
"""Test that configuration loads successfully."""
|
||||
assert settings.GOOGLE_API_KEY is not None
|
||||
assert settings.DATABASE_URL.startswith("sqlite+aiosqlite://")
|
||||
assert settings.CHROMA_PERSIST_DIR == "./chroma_db"
|
||||
|
||||
|
||||
def test_model_names():
|
||||
"""Test that model names are configured correctly."""
|
||||
assert "gemini" in settings.LLM_FLASH_MODEL.lower()
|
||||
assert "gemini" in settings.LLM_PRO_MODEL.lower()
|
||||
assert "embedding" in settings.EMBEDDING_MODEL.lower()
|
||||
|
||||
|
||||
def test_api_settings():
|
||||
"""Test API configuration."""
|
||||
assert isinstance(settings.API_PORT, int)
|
||||
assert settings.API_PORT > 0
|
||||
assert settings.API_HOST is not None
|
||||
73
tests/test_database.py
Normal file
73
tests/test_database.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Test database initialization and basic operations."""
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.data.database import init_db, get_session, engine
|
||||
from app.data.models import Session, Note, Link
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
"""Test database operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_db(self):
|
||||
"""Test database initialization creates tables."""
|
||||
await init_db()
|
||||
|
||||
# Check that tables exist by attempting to create a session
|
||||
async with get_session() as db:
|
||||
session = Session(status="test")
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
assert session.id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_context_manager(self):
|
||||
"""Test database session context manager."""
|
||||
await init_db()
|
||||
|
||||
# Test successful transaction
|
||||
async with get_session() as db:
|
||||
session = Session(status="test")
|
||||
db.add(session)
|
||||
# Context manager should auto-commit
|
||||
|
||||
# Verify data was saved
|
||||
async with get_session() as db:
|
||||
from sqlmodel import select
|
||||
statement = select(Session).where(Session.status == "test")
|
||||
result = await db.exec(statement)
|
||||
found_session = result.first()
|
||||
assert found_session is not None
|
||||
assert found_session.status == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_rollback_on_error(self):
|
||||
"""Test that session rolls back on errors."""
|
||||
await init_db()
|
||||
|
||||
initial_count = 0
|
||||
async with get_session() as db:
|
||||
from sqlmodel import select, func
|
||||
result = await db.exec(select(func.count(Session.id)))
|
||||
initial_count = result.one()
|
||||
|
||||
# Simulate an error during transaction
|
||||
try:
|
||||
async with get_session() as db:
|
||||
session = Session(status="test_rollback")
|
||||
db.add(session)
|
||||
# Force an error
|
||||
raise ValueError("Simulated error")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Verify data was not saved due to rollback
|
||||
async with get_session() as db:
|
||||
from sqlmodel import select, func
|
||||
result = await db.exec(select(func.count(Session.id)))
|
||||
final_count = result.one()
|
||||
assert final_count == initial_count # No new records
|
||||
131
tests/test_models.py
Normal file
131
tests/test_models.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Test data models and validation."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.data.models import (
|
||||
StartSessionRequest,
|
||||
SendMessageRequest,
|
||||
SessionResponse,
|
||||
SessionStatusResponse,
|
||||
RawZettel,
|
||||
SegmentationResult,
|
||||
RawLink,
|
||||
LinkingResult,
|
||||
Note,
|
||||
Session,
|
||||
Link
|
||||
)
|
||||
|
||||
|
||||
class TestAPIModels:
|
||||
"""Test API request/response models."""
|
||||
|
||||
def test_start_session_request(self):
|
||||
"""Test StartSessionRequest validation."""
|
||||
req = StartSessionRequest(topic="AI Ethics")
|
||||
assert req.topic == "AI Ethics"
|
||||
|
||||
# Test empty topic fails
|
||||
with pytest.raises(ValidationError):
|
||||
StartSessionRequest(topic="")
|
||||
|
||||
def test_send_message_request(self):
|
||||
"""Test SendMessageRequest validation."""
|
||||
session_id = uuid.uuid4()
|
||||
req = SendMessageRequest(session_id=session_id, message="Hello")
|
||||
assert req.session_id == session_id
|
||||
assert req.message == "Hello"
|
||||
|
||||
# Test invalid UUID fails
|
||||
with pytest.raises(ValidationError):
|
||||
SendMessageRequest(session_id="invalid", message="Hello")
|
||||
|
||||
def test_session_response(self):
|
||||
"""Test SessionResponse model."""
|
||||
session_id = uuid.uuid4()
|
||||
resp = SessionResponse(
|
||||
session_id=session_id,
|
||||
status="active",
|
||||
message="Welcome!"
|
||||
)
|
||||
assert resp.session_id == session_id
|
||||
assert resp.status == "active"
|
||||
assert resp.message == "Welcome!"
|
||||
|
||||
|
||||
class TestLLMModels:
|
||||
"""Test LLM structured output models."""
|
||||
|
||||
def test_raw_zettel(self):
|
||||
"""Test RawZettel model."""
|
||||
zettel = RawZettel(
|
||||
title="Test Concept",
|
||||
content="This is a test concept about AI.",
|
||||
tags=["ai", "test", "concept"]
|
||||
)
|
||||
assert zettel.title == "Test Concept"
|
||||
assert len(zettel.tags) == 3
|
||||
assert "ai" in zettel.tags
|
||||
|
||||
def test_segmentation_result(self):
|
||||
"""Test SegmentationResult model."""
|
||||
zettels = [
|
||||
RawZettel(title="Concept 1", content="Content 1", tags=["tag1"]),
|
||||
RawZettel(title="Concept 2", content="Content 2", tags=["tag2"])
|
||||
]
|
||||
result = SegmentationResult(notes=zettels)
|
||||
assert len(result.notes) == 2
|
||||
assert result.notes[0].title == "Concept 1"
|
||||
|
||||
def test_raw_link(self):
|
||||
"""Test RawLink model."""
|
||||
link = RawLink(
|
||||
target_note_id=str(uuid.uuid4()),
|
||||
relationship_context="Related through shared AI concepts"
|
||||
)
|
||||
assert uuid.UUID(link.target_note_id) # Should not raise
|
||||
assert "AI" in link.relationship_context
|
||||
|
||||
|
||||
class TestDatabaseModels:
|
||||
"""Test database models."""
|
||||
|
||||
def test_session_creation(self):
|
||||
"""Test Session model creation."""
|
||||
session = Session()
|
||||
assert isinstance(session.id, uuid.UUID)
|
||||
assert session.status == "active"
|
||||
assert session.transcript == []
|
||||
assert isinstance(session.created_at, datetime)
|
||||
|
||||
def test_note_creation(self):
|
||||
"""Test Note model creation."""
|
||||
session_id = uuid.uuid4()
|
||||
note = Note(
|
||||
title="Test Note",
|
||||
content="This is test content",
|
||||
tags=["test", "note"],
|
||||
session_id=session_id
|
||||
)
|
||||
assert isinstance(note.id, uuid.UUID)
|
||||
assert note.title == "Test Note"
|
||||
assert note.content == "This is test content"
|
||||
assert note.tags == ["test", "note"]
|
||||
assert note.session_id == session_id
|
||||
assert isinstance(note.created_at, datetime)
|
||||
|
||||
def test_link_creation(self):
|
||||
"""Test Link model creation."""
|
||||
source_id = uuid.uuid4()
|
||||
target_id = uuid.uuid4()
|
||||
link = Link(
|
||||
context="These concepts are related through...",
|
||||
source_id=source_id,
|
||||
target_id=target_id
|
||||
)
|
||||
assert link.context == "These concepts are related through..."
|
||||
assert link.source_id == source_id
|
||||
assert link.target_id == target_id
|
||||
200
tests/test_repositories.py
Normal file
200
tests/test_repositories.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Test repository pattern implementations."""
|
||||
import uuid
|
||||
import pytest
|
||||
|
||||
from app.data.database import init_db, get_session
|
||||
from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
|
||||
from app.data.models import Session, Note, Link
|
||||
|
||||
|
||||
class TestSessionRepository:
|
||||
"""Test SessionRepository operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self):
|
||||
"""Test creating a new session."""
|
||||
await init_db()
|
||||
|
||||
async with get_session() as db:
|
||||
repo = SessionRepository(db)
|
||||
session = await repo.create(status="active")
|
||||
|
||||
assert session.id is not None
|
||||
assert session.status == "active"
|
||||
assert session.transcript == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session(self):
|
||||
"""Test retrieving a session by ID."""
|
||||
await init_db()
|
||||
|
||||
# Create a session
|
||||
session_id = None
|
||||
async with get_session() as db:
|
||||
repo = SessionRepository(db)
|
||||
session = await repo.create(status="test")
|
||||
session_id = session.id
|
||||
|
||||
# Retrieve the session
|
||||
async with get_session() as db:
|
||||
repo = SessionRepository(db)
|
||||
retrieved = await repo.get(session_id)
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == session_id
|
||||
assert retrieved.status == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status(self):
|
||||
"""Test updating session status."""
|
||||
await init_db()
|
||||
|
||||
# Create a session
|
||||
session_id = None
|
||||
async with get_session() as db:
|
||||
repo = SessionRepository(db)
|
||||
session = await repo.create(status="active")
|
||||
session_id = session.id
|
||||
|
||||
# Update status
|
||||
async with get_session() as db:
|
||||
repo = SessionRepository(db)
|
||||
updated = await repo.update_status(session_id, "completed")
|
||||
|
||||
assert updated is not None
|
||||
assert updated.status == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_transcript(self):
|
||||
"""Test appending to session transcript."""
|
||||
await init_db()
|
||||
|
||||
# Create a session
|
||||
session_id = None
|
||||
async with get_session() as db:
|
||||
repo = SessionRepository(db)
|
||||
session = await repo.create()
|
||||
session_id = session.id
|
||||
|
||||
# Append to transcript
|
||||
message = {"role": "user", "content": "Hello"}
|
||||
async with get_session() as db:
|
||||
repo = SessionRepository(db)
|
||||
updated = await repo.append_transcript(session_id, message)
|
||||
|
||||
assert updated is not None
|
||||
assert len(updated.transcript) == 1
|
||||
assert updated.transcript[0] == message
|
||||
|
||||
|
||||
class TestNoteRepository:
|
||||
"""Test NoteRepository operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_bulk_notes(self):
|
||||
"""Test creating multiple notes at once."""
|
||||
await init_db()
|
||||
|
||||
# Create a session first
|
||||
session_id = None
|
||||
async with get_session() as db:
|
||||
session_repo = SessionRepository(db)
|
||||
session = await session_repo.create()
|
||||
session_id = session.id
|
||||
|
||||
# Create notes
|
||||
notes = [
|
||||
Note(
|
||||
title="Note 1",
|
||||
content="Content 1",
|
||||
tags=["tag1"],
|
||||
session_id=session_id
|
||||
),
|
||||
Note(
|
||||
title="Note 2",
|
||||
content="Content 2",
|
||||
tags=["tag2"],
|
||||
session_id=session_id
|
||||
)
|
||||
]
|
||||
|
||||
async with get_session() as db:
|
||||
repo = NoteRepository(db)
|
||||
created_notes = await repo.create_bulk(notes)
|
||||
|
||||
assert len(created_notes) == 2
|
||||
assert all(note.id is not None for note in created_notes)
|
||||
assert created_notes[0].title == "Note 1"
|
||||
assert created_notes[1].title == "Note 2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_session(self):
|
||||
"""Test retrieving notes by session ID."""
|
||||
await init_db()
|
||||
|
||||
# Create session and notes
|
||||
session_id = None
|
||||
async with get_session() as db:
|
||||
session_repo = SessionRepository(db)
|
||||
session = await session_repo.create()
|
||||
session_id = session.id
|
||||
|
||||
notes = [
|
||||
Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session_id),
|
||||
Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session_id)
|
||||
]
|
||||
|
||||
note_repo = NoteRepository(db)
|
||||
await note_repo.create_bulk(notes)
|
||||
|
||||
# Retrieve notes
|
||||
async with get_session() as db:
|
||||
repo = NoteRepository(db)
|
||||
session_notes = await repo.get_by_session(session_id)
|
||||
|
||||
assert len(session_notes) == 2
|
||||
titles = [note.title for note in session_notes]
|
||||
assert "Note 1" in titles
|
||||
assert "Note 2" in titles
|
||||
|
||||
|
||||
class TestLinkRepository:
|
||||
"""Test LinkRepository operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_bulk_links(self):
|
||||
"""Test creating multiple links at once."""
|
||||
await init_db()
|
||||
|
||||
# Create session and notes first
|
||||
note_ids = []
|
||||
async with get_session() as db:
|
||||
session_repo = SessionRepository(db)
|
||||
session = await session_repo.create()
|
||||
|
||||
notes = [
|
||||
Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session.id),
|
||||
Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session.id)
|
||||
]
|
||||
|
||||
note_repo = NoteRepository(db)
|
||||
created_notes = await note_repo.create_bulk(notes)
|
||||
note_ids = [note.id for note in created_notes]
|
||||
|
||||
# Create links
|
||||
links = [
|
||||
Link(
|
||||
context="These notes are related because...",
|
||||
source_id=note_ids[0],
|
||||
target_id=note_ids[1]
|
||||
)
|
||||
]
|
||||
|
||||
async with get_session() as db:
|
||||
repo = LinkRepository(db)
|
||||
created_links = await repo.create_bulk(links)
|
||||
|
||||
assert len(created_links) == 1
|
||||
assert created_links[0].id is not None
|
||||
assert created_links[0].source_id == note_ids[0]
|
||||
assert created_links[0].target_id == note_ids[1]
|
||||
Reference in New Issue
Block a user