- Define SQLModel schemas for Session, Note, and Link entities - Add API request/response models for RPC endpoints - Create LLM structured output models for Zettel extraction - Set up async database initialization with SQLModel and aiosqlite - Implement repository pattern for CRUD operations - Add complete test suite with pytest configuration - Create validation test runner for development workflow - Add .gitignore for Python/FastAPI project security
98 lines
3.5 KiB
Python
98 lines
3.5 KiB
Python
import uuid
|
|
from typing import List, Optional
|
|
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.data.models import Session, Note, Link
|
|
|
|
|
|
class SessionRepository:
    """Data-access helpers for ``Session`` rows, bound to one ``AsyncSession``.

    ``create``, ``update_status`` and ``append_transcript`` commit immediately
    and then refresh the affected instance before returning it.
    """

    def __init__(self, session: AsyncSession):
        # Async database session that every query and commit below goes through.
        self.session = session

    async def _commit_and_refresh(self, db_session: Session) -> Session:
        """Commit pending changes, reload *db_session* from the DB, return it."""
        await self.session.commit()
        await self.session.refresh(db_session)
        return db_session

    async def create(self, status: str = "active") -> Session:
        """Insert a new ``Session`` with the given *status* and return it.

        Args:
            status: Initial status value (defaults to ``"active"``).

        Returns:
            The newly persisted, refreshed ``Session`` instance.
        """
        db_session = Session(status=status)
        self.session.add(db_session)
        return await self._commit_and_refresh(db_session)

    async def get(self, session_id: uuid.UUID) -> Optional[Session]:
        """Return the ``Session`` whose ``id`` equals *session_id*, or ``None``."""
        statement = select(Session).where(Session.id == session_id)
        result = await self.session.exec(statement)
        return result.first()

    async def update_status(self, session_id: uuid.UUID, status: str) -> Optional[Session]:
        """Set the status of an existing session and commit.

        Returns:
            The refreshed ``Session``, or ``None`` if no row matches
            *session_id* (nothing is committed in that case).
        """
        db_session = await self.get(session_id)
        if db_session is None:
            return None
        db_session.status = status
        return await self._commit_and_refresh(db_session)

    async def append_transcript(self, session_id: uuid.UUID, message: dict) -> Optional[Session]:
        """Append *message* to a session's transcript and commit.

        Returns:
            The refreshed ``Session``, or ``None`` if no row matches
            *session_id* (nothing is committed in that case).
        """
        db_session = await self.get(session_id)
        if db_session is None:
            return None
        # Build a new list and reassign it rather than appending in place:
        # presumably ``transcript`` is a JSON column, whose in-place mutation
        # the ORM does not track. A NULL transcript is treated as empty.
        db_session.transcript = [*(db_session.transcript or []), message]
        return await self._commit_and_refresh(db_session)

    async def get_with_notes(self, session_id: uuid.UUID) -> Optional[Session]:
        """Return the ``Session`` with its ``notes`` collection loaded, or ``None``.

        Assumes ``Session.notes`` is an ORM relationship backed by
        ``Note.session_id``; confirm against ``app.data.models``.
        """
        db_session = await self.get(session_id)
        if db_session is None:
            return None
        # Assigning to the relationship attribute would make SQLAlchemy lazy-load
        # the old collection first, which is implicit IO that fails under an
        # AsyncSession. Explicitly refreshing the named relationship loads it
        # with awaited IO instead.
        await self.session.refresh(db_session, attribute_names=["notes"])
        return db_session
|
|
|
|
|
|
class NoteRepository:
    """Persistence operations for ``Note`` records over an ``AsyncSession``.

    Only ``create_bulk`` writes; it commits before returning.
    """

    def __init__(self, session: AsyncSession):
        # Async DB session used for every query and commit issued here.
        self.session = session

    async def create_bulk(self, notes: List[Note]) -> List[Note]:
        """Persist every note in *notes* with a single commit.

        Returns:
            The same list that was passed in, after each note has been
            refreshed from the database.
        """
        self.session.add_all(notes)
        await self.session.commit()
        for persisted in notes:
            await self.session.refresh(persisted)
        return notes

    async def get(self, note_id: uuid.UUID) -> Optional[Note]:
        """Return the note whose ``id`` equals *note_id*, or ``None``."""
        query = select(Note).where(Note.id == note_id)
        rows = await self.session.exec(query)
        return rows.first()

    async def get_by_session(self, session_id: uuid.UUID) -> List[Note]:
        """Return all notes whose ``session_id`` equals *session_id*."""
        query = select(Note).where(Note.session_id == session_id)
        return (await self.session.exec(query)).all()
|
|
|
|
|
|
class LinkRepository:
    """Persistence operations for ``Link`` records over an ``AsyncSession``.

    Only ``create_bulk`` writes; it commits before returning.
    """

    def __init__(self, session: AsyncSession):
        # Async DB session used for every query and commit issued here.
        self.session = session

    async def create_bulk(self, links: List[Link]) -> List[Link]:
        """Persist every link in *links* with a single commit.

        Returns:
            The same list that was passed in, after each link has been
            refreshed from the database.
        """
        self.session.add_all(links)
        await self.session.commit()
        for persisted in links:
            await self.session.refresh(persisted)
        return links

    async def get_by_source(self, source_id: uuid.UUID) -> List[Link]:
        """Return all links whose ``source_id`` equals *source_id*."""
        query = select(Link).where(Link.source_id == source_id)
        return (await self.session.exec(query)).all()

    async def get_by_target(self, target_id: uuid.UUID) -> List[Link]:
        """Return all links whose ``target_id`` equals *target_id*."""
        query = select(Link).where(Link.target_id == target_id)
        return (await self.session.exec(query)).all()