- Define SQLModel schemas for Session, Note, and Link entities
- Add API request/response models for RPC endpoints
- Create LLM structured-output models for Zettel extraction
- Set up async database initialization with SQLModel and aiosqlite
- Implement the repository pattern for CRUD operations
- Add a complete test suite with pytest configuration
- Create a validation test runner for the development workflow
- Add a .gitignore for the Python/FastAPI project to improve security
200 lines
6.5 KiB
Python
200 lines
6.5 KiB
Python
"""Test repository pattern implementations."""
|
|
import uuid
|
|
import pytest
|
|
|
|
from app.data.database import init_db, get_session
|
|
from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
|
|
from app.data.models import Session, Note, Link
|
|
|
|
|
|
class TestSessionRepository:
    """Test SessionRepository operations.

    Each test opens a separate ``get_session()`` context for each phase
    (arrange / act / verify). Presumably every call yields a new DB session
    (TODO confirm against app.data.database). Under that assumption, checks
    made in a later context read back what was actually persisted, not an
    object cached by the session that wrote it.
    """

    @staticmethod
    async def _create_session(**fields):
        """Create a session via ``SessionRepository.create(**fields)`` and return its id."""
        async with get_session() as db:
            session = await SessionRepository(db).create(**fields)
            return session.id

    @pytest.mark.asyncio
    async def test_create_session(self):
        """Test creating a new session."""
        await init_db()

        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="active")

            assert session.id is not None
            assert session.status == "active"
            assert session.transcript == []

    @pytest.mark.asyncio
    async def test_get_session(self):
        """Test retrieving a session by ID."""
        await init_db()

        # Create a session
        session_id = await self._create_session(status="test")

        # Retrieve the session through a separate context
        async with get_session() as db:
            retrieved = await SessionRepository(db).get(session_id)

            assert retrieved is not None
            assert retrieved.id == session_id
            assert retrieved.status == "test"

    @pytest.mark.asyncio
    async def test_update_status(self):
        """Test updating session status and that the change is persisted."""
        await init_db()

        # Create a session
        session_id = await self._create_session(status="active")

        # Update status
        async with get_session() as db:
            updated = await SessionRepository(db).update_status(
                session_id, "completed"
            )

            assert updated is not None
            assert updated.status == "completed"

        # The returned object showing the new value does not prove the
        # change was committed, so read it back in a fresh context.
        async with get_session() as db:
            reloaded = await SessionRepository(db).get(session_id)

            assert reloaded is not None
            assert reloaded.status == "completed"

    @pytest.mark.asyncio
    async def test_append_transcript(self):
        """Test appending to session transcript and that the append is persisted."""
        await init_db()

        # Create a session
        session_id = await self._create_session()

        # Append to transcript
        message = {"role": "user", "content": "Hello"}
        async with get_session() as db:
            updated = await SessionRepository(db).append_transcript(
                session_id, message
            )

            assert updated is not None
            assert len(updated.transcript) == 1
            assert updated.transcript[0] == message

        # Read back in a fresh context. The transcript is presumably stored in
        # a JSON column (TODO confirm in app.data.models). If so, an in-place
        # list append can look correct on the returned object yet never be
        # flushed, because SQLAlchemy does not detect in-place mutation of a
        # plain JSON value.
        async with get_session() as db:
            reloaded = await SessionRepository(db).get(session_id)

            assert reloaded is not None
            assert reloaded.transcript == [message]
|
|
|
|
|
|
class TestNoteRepository:
    """Test NoteRepository operations."""

    @staticmethod
    async def _create_session_id():
        """Create an empty session and return its id, for notes to reference."""
        async with get_session() as db:
            session = await SessionRepository(db).create()
            return session.id

    @pytest.mark.asyncio
    async def test_create_bulk_notes(self):
        """Test creating multiple notes at once."""
        await init_db()

        # Create a session first
        session_id = await self._create_session_id()

        # Create notes
        notes = [
            Note(
                title="Note 1",
                content="Content 1",
                tags=["tag1"],
                session_id=session_id,
            ),
            Note(
                title="Note 2",
                content="Content 2",
                tags=["tag2"],
                session_id=session_id,
            ),
        ]

        async with get_session() as db:
            created_notes = await NoteRepository(db).create_bulk(notes)

            assert len(created_notes) == 2
            assert all(note.id is not None for note in created_notes)
            # create_bulk is expected to return notes in input order.
            assert [note.title for note in created_notes] == ["Note 1", "Note 2"]
            assert all(note.session_id == session_id for note in created_notes)

    @pytest.mark.asyncio
    async def test_get_by_session(self):
        """Test retrieving notes by session ID, excluding other sessions' notes."""
        await init_db()

        # Create two sessions: the one under test and a decoy.
        session_id = await self._create_session_id()
        other_session_id = await self._create_session_id()

        async with get_session() as db:
            await NoteRepository(db).create_bulk(
                [
                    Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session_id),
                    Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session_id),
                    # Decoy in another session. If init_db() does not reset the
                    # database, the filter would otherwise only be exercised
                    # when other tests happened to leave notes behind.
                    Note(
                        title="Other note",
                        content="Other content",
                        tags=["other"],
                        session_id=other_session_id,
                    ),
                ]
            )

        # Retrieve notes
        async with get_session() as db:
            session_notes = await NoteRepository(db).get_by_session(session_id)

            assert len(session_notes) == 2
            assert {note.title for note in session_notes} == {"Note 1", "Note 2"}
            assert all(note.session_id == session_id for note in session_notes)
|
|
|
|
|
|
class TestLinkRepository:
    """Exercise LinkRepository bulk creation."""

    @pytest.mark.asyncio
    async def test_create_bulk_links(self):
        """Test creating multiple links at once."""
        await init_db()

        # Arrange: one session owning two notes that will be linked together.
        async with get_session() as db:
            parent = await SessionRepository(db).create()
            saved_notes = await NoteRepository(db).create_bulk(
                [
                    Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=parent.id),
                    Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=parent.id),
                ]
            )
            note_ids = [saved.id for saved in saved_notes]

        # Act: persist a single link from the first note to the second.
        pending_links = [
            Link(
                context="These notes are related because...",
                source_id=note_ids[0],
                target_id=note_ids[1],
            ),
        ]

        async with get_session() as db:
            created_links = await LinkRepository(db).create_bulk(pending_links)

            # Assert: exactly one link, with an id and the expected endpoints.
            assert len(created_links) == 1
            link = created_links[0]
            assert link.id is not None
            assert link.source_id == note_ids[0]
            assert link.target_id == note_ids[1]