Files
app/app/data/repositories.py
Albert f60d61a78f feat: implement data layer with comprehensive test infrastructure
- Define SQLModel schemas for Session, Note, and Link entities
  - Add API request/response models for RPC endpoints
  - Create LLM structured output models for Zettel extraction
  - Set up async database initialization with SQLModel and aiosqlite
  - Implement repository pattern for CRUD operations
  - Add complete test suite with pytest configuration
  - Create validation test runner for development workflow
  - Add .gitignore for Python/FastAPI project security
2025-08-17 01:25:16 +00:00

98 lines
3.5 KiB
Python

import uuid
from typing import List, Optional
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.data.models import Session, Note, Link
class SessionRepository:
    """Async CRUD operations for ``Session`` rows.

    Every mutating method commits the injected ``AsyncSession`` immediately
    and refreshes the affected instance before returning it.
    """

    def __init__(self, session: AsyncSession):
        # The database session. This is distinct from the ``Session`` model
        # class this repository manages.
        self.session = session

    async def _commit_and_refresh(self, db_session: Session) -> Session:
        """Commit pending changes, reload *db_session*, and return it.

        If the commit fails, the transaction is rolled back before the error
        propagates. Otherwise the shared ``AsyncSession`` would be left in a
        "pending rollback" state and reject every later operation.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise
        # By default, commit() expires loaded attributes. Refreshing reloads
        # them explicitly, so callers do not trigger implicit IO on access.
        await self.session.refresh(db_session)
        return db_session

    async def create(self, status: str = "active") -> Session:
        """Insert a new session with the given *status* and return it."""
        db_session = Session(status=status)
        self.session.add(db_session)
        return await self._commit_and_refresh(db_session)

    async def get(self, session_id: uuid.UUID) -> Optional[Session]:
        """Return the session with id *session_id*, or ``None`` if absent."""
        statement = select(Session).where(Session.id == session_id)
        result = await self.session.exec(statement)
        return result.first()

    async def update_status(self, session_id: uuid.UUID, status: str) -> Optional[Session]:
        """Set the status of session *session_id* and return the updated row.

        Returns ``None`` (and writes nothing) when the session does not exist.
        """
        db_session = await self.get(session_id)
        if db_session is None:
            return None
        db_session.status = status
        return await self._commit_and_refresh(db_session)

    async def append_transcript(self, session_id: uuid.UUID, message: dict) -> Optional[Session]:
        """Append *message* to the transcript of session *session_id*.

        Returns the updated session, or ``None`` when it does not exist.
        A NULL transcript is treated as empty.
        """
        db_session = await self.get(session_id)
        if db_session is None:
            return None
        # Assign a brand-new list rather than mutating the existing one in
        # place, so the attribute change is visible to the ORM.
        # NOTE(review): presumably `transcript` is a JSON column without
        # mutation tracking. Confirm against app.data.models.
        db_session.transcript = [*(db_session.transcript or []), message]
        return await self._commit_and_refresh(db_session)

    async def get_with_notes(self, session_id: uuid.UUID) -> Optional[Session]:
        """Return session *session_id* with its ``notes`` populated.

        Returns ``None`` when the session does not exist. The notes are
        fetched with a separate query on ``Note.session_id``.
        """
        db_session = await self.get(session_id)
        if db_session is None:
            return None
        notes_statement = select(Note).where(Note.session_id == session_id)
        notes_result = await self.session.exec(notes_statement)
        # NOTE(review): `notes` is presumably an ORM relationship. If it is not
        # already loaded, assigning to it may make SQLAlchemy load the old
        # collection (implicit IO under asyncio). Confirm the relationship's
        # loader strategy, or switch to an eager-load option.
        db_session.notes = notes_result.all()
        return db_session
class NoteRepository:
    """Async persistence operations for ``Note`` rows."""

    def __init__(self, session: AsyncSession):
        # The database session shared with the caller.
        self.session = session

    async def create_bulk(self, notes: List[Note]) -> List[Note]:
        """Insert all *notes* in a single commit and return them refreshed.

        The commit applies to the whole shared ``AsyncSession``. If it fails,
        the transaction is rolled back so the session stays usable, and the
        original error is re-raised.
        """
        self.session.add_all(notes)
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise
        # By default, commit() expires loaded attributes. Reload each note
        # explicitly so callers can read fields without implicit IO.
        for note in notes:
            await self.session.refresh(note)
        return notes

    async def get(self, note_id: uuid.UUID) -> Optional[Note]:
        """Return the note with id *note_id*, or ``None`` if absent."""
        statement = select(Note).where(Note.id == note_id)
        result = await self.session.exec(statement)
        return result.first()

    async def get_by_session(self, session_id: uuid.UUID) -> List[Note]:
        """Return every note whose ``session_id`` equals *session_id*."""
        statement = select(Note).where(Note.session_id == session_id)
        result = await self.session.exec(statement)
        # Materialize as a concrete list to match the annotated return type.
        return list(result.all())
class LinkRepository:
    """Async persistence operations for ``Link`` rows."""

    def __init__(self, session: AsyncSession):
        # The database session shared with the caller.
        self.session = session

    async def create_bulk(self, links: List[Link]) -> List[Link]:
        """Insert all *links* in a single commit and return them refreshed.

        The commit applies to the whole shared ``AsyncSession``. If it fails,
        the transaction is rolled back so the session stays usable, and the
        original error is re-raised.
        """
        self.session.add_all(links)
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise
        # By default, commit() expires loaded attributes. Reload each link
        # explicitly so callers can read fields without implicit IO.
        for link in links:
            await self.session.refresh(link)
        return links

    async def get_by_source(self, source_id: uuid.UUID) -> List[Link]:
        """Return every link whose ``source_id`` equals *source_id*."""
        statement = select(Link).where(Link.source_id == source_id)
        result = await self.session.exec(statement)
        # Materialize as a concrete list to match the annotated return type.
        return list(result.all())

    async def get_by_target(self, target_id: uuid.UUID) -> List[Link]:
        """Return every link whose ``target_id`` equals *target_id*."""
        statement = select(Link).where(Link.target_id == target_id)
        result = await self.session.exec(statement)
        # Materialize as a concrete list to match the annotated return type.
        return list(result.all())