feat: implement data layer with comprehensive test infrastructure

- Define SQLModel schemas for Session, Note, and Link entities
- Add API request/response models for RPC endpoints
- Create LLM structured output models for Zettel extraction
- Set up async database initialization with SQLModel and aiosqlite
- Implement repository pattern for CRUD operations
- Add complete test suite with pytest configuration
- Create validation test runner for development workflow
- Add .gitignore for Python/FastAPI project security
.claude/settings.local.json (new file, 12 lines added)
@@ -0,0 +1,12 @@
{
  "permissions": {
    "allow": [
      "Bash(mkdir:*)",
      "Bash(python run_tests.py:*)",
      "Bash(source:*)",
      "Bash(pytest:*)"
    ],
    "deny": [],
    "ask": []
  }
}
.gitignore (new file, vendored, 104 lines added)
@@ -0,0 +1,104 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Virtual environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDE specific files
.vscode/
.idea/
*.swp
*.swo
*~

# OS specific files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Project specific files
*.db
*.sqlite
*.sqlite3
skytalk.db*

# ChromaDB vector database
chroma_db/
*.chroma

# Logs
*.log
logs/

# Environment variables (keep template)
.env.local
.env.production
.env.development

# FastAPI specific
.mypy_cache/
.dmypy.json
dmypy.json

# Jupyter Notebooks
.ipynb_checkpoints
*.ipynb

# Documentation builds
docs/_build/
site/
CLAUDE.md (new file, 116 lines added)
@@ -0,0 +1,116 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with
code in this repository.

## Project Overview

SkyTalk API is an AI-powered backend service for conversational idea exploration
and knowledge synthesis. It interviews users to explore ideas and synthesizes
conversations into a structured, semantically linked knowledge base using the
Zettelkasten method.

## Core Development Commands

### Environment Setup

```bash
# Install dependencies using uv package manager
uv pip install -r requirements.txt

# Run the API server (async FastAPI)
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```

### Code Quality

```bash
# Format code with black (mandatory)
black .

# Type checking
mypy .

# Linting
ruff check .
```

## Architecture & Implementation Standards

### Three-Layer Architecture

1. **API Layer**: FastAPI RPC-style endpoints (`/sessions/start`,
   `/sessions/message`)
2. **Services Layer**: Orchestration logic with LangChain LCEL
3. **Data Layer**: SQLModel (SQLite) + ChromaDB (embedded, persisted)

### Critical Implementation Rules

**Database Models**:

- SQLModel is the single source of truth for all database schemas
- ChromaDB runs in embedded mode, persisted to disk
- All database operations must be async

**AI Integration**:

- **Interviewing**: Use `gemini-2.5-flash-latest` (optimized for speed)
- **Synthesis/Linking**: Use `gemini-2.5-pro-latest` (optimized for reasoning)
- **Embeddings**: Use `models/text-embedding-004`
- **Structured Output**: Always use `.with_structured_output()` with Pydantic
  models for data extraction; never parse raw text

**Code Standards**:

- Maximum 400 lines per Python file
- Full type hints required (Pydantic V2 for API, SQLModel for DB)
- All I/O operations must use async/await
- Configuration via environment variables with `pydantic-settings`
- Use `HTTPException` for client errors

### Project Structure

```
api/
├── app/
│   ├── api/          # FastAPI endpoints (RPC-style)
│   ├── services/     # Business logic, LangChain agents
│   │   ├── interviewer.py
│   │   ├── synthesizer.py
│   │   └── vector.py
│   ├── data/         # Repositories and database models
│   │   ├── models/   # SQLModel definitions
│   │   └── repositories/
│   ├── core/         # Configuration, prompts
│   └── main.py       # FastAPI app initialization
├── requirements.txt  # Dependencies managed via uv
└── .env              # Environment variables
```

### Key Implementation Patterns

**LangChain LCEL Pipeline**:

```python
chain = prompt | llm.with_structured_output(OutputModel) | parser
result = await chain.ainvoke({"input": data})
```

**Async Database Operations**:

```python
async def fetch_session(session_id: str) -> Session:
    # named fetch_session so it does not shadow the get_session() context manager
    async with get_session() as db:
        return await db.get(Session, session_id)
```

**Background Task for Synthesis**:

```python
background_tasks.add_task(synthesize_session, session_id)
```

### Overall Style

- NO sycophancy -- push back/suggest alternative routes when it would help
  improve the project
README.md (57 lines changed)
@@ -81,3 +81,60 @@ SkyTalkAPI: {
# Connections to External APIs
SkyTalkAPI.Services -> ExternalAPIs.Gemini
```

---

## MVP Features (Phase 1)

At the completion of this implementation plan, the SkyTalk API will support the
following core features:

- **Session Management:** Start a new interview session based on an initial
  topic.

- **RAG-Powered Interviewing:** Engage in a back-and-forth conversation where
  the AI's questions are informed by existing knowledge in the vector store.

- **Automatic Session Termination:** The AI can detect a natural conclusion to
  the conversation.

- **Asynchronous Synthesis:** Once the interview ends, a background process is
  triggered to analyze the transcript.

- **Semantic Segmentation:** The transcript is intelligently broken down into
  atomic "Zettels" (notes), each focusing on a single concept.

- **Vector Indexing:** Each new note is converted into a vector embedding and
  stored for future RAG.

- **Generative Linking:** The system identifies semantically related notes and
  uses an LLM to generate a rich, contextual link explaining the relationship
  between them.

- **Status Tracking:** Endpoints to check the status of a session (active,
  processing, completed).

---

## Next Steps (Post-MVP)

- **Authentication:** Implement user accounts and authentication (e.g., JWT) to
  create user-specific knowledge bases.

- **Frontend Integration:** Build a web-based frontend (e.g., at
  www.skytalk.app) that consumes this API.

- **Knowledge Graph Visualization:** Add endpoints to export the note-and-link
  structure in a format suitable for graph visualization libraries (e.g., D3.js,
  Vis.js).

- **Note Editing and Management:** Provide endpoints for users to manually edit,
  delete, or merge notes.

- **Advanced Search:** Implement more sophisticated search functionalities
  beyond simple semantic search, such as filtering by tags or searching within
  link contexts.

- **Scalable Infrastructure:** Migrate from SQLite/embedded ChromaDB to a
  production-grade database (e.g., PostgreSQL with pgvector) and a managed
  vector database for scalability.
app/core/config.py (new file, 19 lines added)
@@ -0,0 +1,19 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    GOOGLE_API_KEY: str
    DATABASE_URL: str = "sqlite+aiosqlite:///./skytalk.db"
    CHROMA_PERSIST_DIR: str = "./chroma_db"

    LLM_FLASH_MODEL: str = "gemini-2.5-flash-latest"
    LLM_PRO_MODEL: str = "gemini-2.5-pro-latest"
    EMBEDDING_MODEL: str = "models/text-embedding-004"

    API_HOST: str = "0.0.0.0"
    API_PORT: int = 8000


settings = Settings()
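As an aside on how `Settings` behaves at runtime: pydantic-settings resolves each field from the environment (and the `.env` file), falling back to the declared default. A stdlib-only sketch of that precedence; the `setting` helper and `DEFAULTS` dict are hypothetical names for illustration, not part of the codebase.

```python
# Stdlib sketch of the precedence Settings uses: an environment variable
# overrides the default declared on the model. DEFAULTS mirrors two of the
# fields above; pydantic-settings does this resolution declaratively.
import os

DEFAULTS = {
    "DATABASE_URL": "sqlite+aiosqlite:///./skytalk.db",
    "API_PORT": "8000",
}

def setting(name: str) -> str:
    # env wins; otherwise fall back to the declared default
    return os.environ.get(name, DEFAULTS[name])
```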
app/data/database.py (new file, 33 lines added)
@@ -0,0 +1,33 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.config import settings


engine = create_async_engine(
    settings.DATABASE_URL,
    echo=False,
    future=True
)


async def init_db() -> None:
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)


@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    async with AsyncSession(engine) as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
app/data/models.py (new file, 73 lines added)
@@ -0,0 +1,73 @@
import uuid
from datetime import datetime
from typing import List, Optional, Any
from sqlmodel import Field, SQLModel, Relationship, JSON, Column


class StartSessionRequest(SQLModel):
    topic: str = Field(min_length=1, description="Topic for the session")


class SendMessageRequest(SQLModel):
    session_id: uuid.UUID
    message: str = Field(min_length=1, description="User message content")


class SessionResponse(SQLModel):
    session_id: uuid.UUID
    status: str
    message: str


class SessionStatusResponse(SQLModel):
    session_id: uuid.UUID
    status: str
    notes_count: int = 0
    created_at: datetime


class RawZettel(SQLModel):
    title: str = Field(description="Concise title of the atomic concept.")
    content: str = Field(description="The synthesized mini-blog post.")
    tags: List[str] = Field(description="List of relevant conceptual keywords.")


class SegmentationResult(SQLModel):
    notes: List[RawZettel]


class RawLink(SQLModel):
    target_note_id: str = Field(description="UUID of the related note.")
    relationship_context: str = Field(description="Explanation of how the new note relates to the target note.")


class LinkingResult(SQLModel):
    links: List[RawLink]


class Link(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    context: str

    source_id: uuid.UUID = Field(foreign_key="note.id")
    target_id: uuid.UUID = Field(foreign_key="note.id")


class Note(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    title: str
    content: str
    tags: List[str] = Field(sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)

    session_id: uuid.UUID = Field(foreign_key="session.id")
    session: Optional["Session"] = Relationship(back_populates="notes")


class Session(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    status: str = Field(default="active")
    transcript: List[dict] = Field(default=[], sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)

    notes: List["Note"] = Relationship(back_populates="session")
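A quick illustration of why the structured-output models exist: the synthesis step validates the LLM's JSON against a fixed shape instead of scraping free text. A stdlib sketch using `dataclasses` in place of SQLModel (an assumption made for portability; `parse_segmentation` is a hypothetical helper mirroring the `RawZettel`/`SegmentationResult` shapes).

```python
# Stdlib sketch: a malformed or missing field raises immediately at the
# parse boundary instead of propagating bad data downstream.
import json
from dataclasses import dataclass
from typing import List

@dataclass
class RawZettelSketch:
    title: str
    content: str
    tags: List[str]

def parse_segmentation(payload: str) -> List[RawZettelSketch]:
    data = json.loads(payload)
    # unknown or absent keys raise TypeError here, enforcing the schema
    return [RawZettelSketch(**note) for note in data["notes"]]
```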
app/data/repositories.py (new file, 98 lines added)
@@ -0,0 +1,98 @@
import uuid
from typing import List, Optional

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.data.models import Session, Note, Link


class SessionRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create(self, status: str = "active") -> Session:
        db_session = Session(status=status)
        self.session.add(db_session)
        await self.session.commit()
        await self.session.refresh(db_session)
        return db_session

    async def get(self, session_id: uuid.UUID) -> Optional[Session]:
        statement = select(Session).where(Session.id == session_id)
        result = await self.session.exec(statement)
        return result.first()

    async def update_status(self, session_id: uuid.UUID, status: str) -> Optional[Session]:
        db_session = await self.get(session_id)
        if db_session:
            db_session.status = status
            await self.session.commit()
            await self.session.refresh(db_session)
        return db_session

    async def append_transcript(self, session_id: uuid.UUID, message: dict) -> Optional[Session]:
        db_session = await self.get(session_id)
        if db_session:
            transcript = db_session.transcript.copy()
            transcript.append(message)
            db_session.transcript = transcript
            await self.session.commit()
            await self.session.refresh(db_session)
        return db_session

    async def get_with_notes(self, session_id: uuid.UUID) -> Optional[Session]:
        statement = select(Session).where(Session.id == session_id)
        result = await self.session.exec(statement)
        session = result.first()
        if session:
            notes_statement = select(Note).where(Note.session_id == session_id)
            notes_result = await self.session.exec(notes_statement)
            session.notes = notes_result.all()
        return session


class NoteRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create_bulk(self, notes: List[Note]) -> List[Note]:
        for note in notes:
            self.session.add(note)
        await self.session.commit()
        for note in notes:
            await self.session.refresh(note)
        return notes

    async def get(self, note_id: uuid.UUID) -> Optional[Note]:
        statement = select(Note).where(Note.id == note_id)
        result = await self.session.exec(statement)
        return result.first()

    async def get_by_session(self, session_id: uuid.UUID) -> List[Note]:
        statement = select(Note).where(Note.session_id == session_id)
        result = await self.session.exec(statement)
        return result.all()


class LinkRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create_bulk(self, links: List[Link]) -> List[Link]:
        for link in links:
            self.session.add(link)
        await self.session.commit()
        for link in links:
            await self.session.refresh(link)
        return links

    async def get_by_source(self, source_id: uuid.UUID) -> List[Link]:
        statement = select(Link).where(Link.source_id == source_id)
        result = await self.session.exec(statement)
        return result.all()

    async def get_by_target(self, target_id: uuid.UUID) -> List[Link]:
        statement = select(Link).where(Link.target_id == target_id)
        result = await self.session.exec(statement)
        return result.all()
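The repository pattern above keeps SQL behind small, independently testable methods. A synchronous stdlib sketch of the same idea (`SessionRepo` is a hypothetical stand-in; the committed classes are async and SQLModel-backed):

```python
# Stdlib sketch of SessionRepository's create / get / update_status
# surface over sqlite3, illustrating the pattern rather than the
# committed implementation.
import sqlite3
import uuid

class SessionRepo:
    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn
        conn.execute(
            "CREATE TABLE IF NOT EXISTS session (id TEXT PRIMARY KEY, status TEXT)"
        )

    def create(self, status: str = "active") -> str:
        sid = str(uuid.uuid4())
        self.conn.execute("INSERT INTO session VALUES (?, ?)", (sid, status))
        self.conn.commit()
        return sid

    def get(self, sid: str):
        row = self.conn.execute(
            "SELECT status FROM session WHERE id = ?", (sid,)
        ).fetchone()
        return row[0] if row else None

    def update_status(self, sid: str, status: str) -> None:
        self.conn.execute(
            "UPDATE session SET status = ? WHERE id = ?", (status, sid)
        )
        self.conn.commit()
```

Callers never touch SQL directly, which is what makes the async versions easy to mock in the test suite.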
implementation-plan.txt (new file, 241 lines added)
@@ -0,0 +1,241 @@
Product Requirements Document (PRD) & Implementation Plan
1. Overview
SkyTalk (api.skytalk.app) is an API service that functions as an AI interviewer to elicit user ideas, then processes the conversation to generate "Zettels" (atomic mini-blog posts). These Zettels are semantically linked using RAG and "Generative Linking."

2. Implementation Tasks (Linear Execution Plan)
Phase 1: Project Setup and Core Utilities
Task 1.1: Project Structure Initialization
Create the following directory and file structure:

Bash

skytalk-api/
├── .venv/
├── src/
│   ├── config.py
│   ├── db/
│   │   ├── database.py
│   │   └── repositories.py
│   ├── services/
│   │   ├── agents/
│   │   │   ├── interviewer.py
│   │   │   ├── synthesizer.py
│   │   │   └── prompts.py
│   │   ├── session_service.py
│   │   └── vector_service.py
│   ├── models.py
│   └── main.py
├── .env
└── requirements.txt

Task 1.2: Dependencies and Configuration

Create requirements.txt:

Plaintext

fastapi
uvicorn[standard]
pydantic-settings
sqlmodel
aiosqlite
langchain
langchain-google-genai
langchain-community
chromadb
tiktoken

Create src/config.py:

Python

from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    GOOGLE_API_KEY: str
    DATABASE_URL: str = "sqlite+aiosqlite:///./skytalk.db"
    CHROMA_PERSIST_DIR: str = "./chroma_db"

    LLM_FLASH_MODEL: str = "gemini-2.5-flash-latest"
    LLM_PRO_MODEL: str = "gemini-2.5-pro-latest"
    EMBEDDING_MODEL: str = "models/text-embedding-004"

settings = Settings()

Phase 2: Data Models and Persistence
Task 2.1: SQLModel Definitions (src/models.py)
Define all data structures, combining database models and API schemas.

Python

import uuid
from datetime import datetime
from typing import List, Optional, Any
from sqlmodel import Field, SQLModel, Relationship, JSON, Column

# --- API Payloads (Not stored in DB) ---
class StartSessionRequest(SQLModel):
    topic: str

class SendMessageRequest(SQLModel):
    session_id: uuid.UUID
    message: str

# --- LLM Structured Output Models ---
class RawZettel(SQLModel):
    title: str = Field(description="Concise title of the atomic concept.")
    content: str = Field(description="The synthesized mini-blog post.")
    tags: List[str] = Field(description="List of relevant conceptual keywords.")

class SegmentationResult(SQLModel):
    notes: List[RawZettel]

class RawLink(SQLModel):
    target_note_id: str = Field(description="UUID of the related note.")
    relationship_context: str = Field(description="Explanation of how the new note relates to the target note.")

class LinkingResult(SQLModel):
    links: List[RawLink]

# --- Database Models ---
class Link(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    context: str

    source_id: uuid.UUID = Field(foreign_key="note.id")
    target_id: uuid.UUID = Field(foreign_key="note.id")

class Note(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    title: str
    content: str
    tags: List[str] = Field(sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)

    session_id: uuid.UUID = Field(foreign_key="session.id")
    session: "Session" = Relationship(back_populates="notes")

class Session(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    status: str = Field(default="active")  # active, processing, completed, failed
    transcript: List[dict] = Field(default=[], sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)

    notes: List["Note"] = Relationship(back_populates="session")

Task 2.2: Database Initialization (src/db/database.py)

Create the engine with create_async_engine (from sqlalchemy.ext.asyncio) and use AsyncSession from sqlmodel.ext.asyncio.session.

Define init_db() to create tables on startup.

Implement a FastAPI dependency get_session() to inject the AsyncSession.

Task 2.3: Repositories (src/db/repositories.py)
Implement async CRUD operations using SQLModel's syntax.

SessionRepository: create, get, update_status, append_transcript.

NoteRepository: create_bulk (efficiently inserts multiple notes).

LinkRepository: create_bulk.

Phase 3: Core Services
Task 3.1: Vector Store Service (src/services/vector_service.py)

Initialize GoogleGenerativeAIEmbeddings from langchain_google_genai.

Initialize a persistent Chroma client.

async add_notes(notes: List[Note]): Adds note content and metadata to Chroma, using note ids as document IDs.

async semantic_search(query: str, k: int = 5): Performs similarity search and returns LangChain Document objects.

Task 3.2: Prompts (src/services/agents/prompts.py)

INTERVIEWER_SYSTEM_PROMPT: Socratic persona. Accepts {retrieved_context}. Instructs the AI to ask one question at a time and to output the exact token [END_SESSION] when the conversation naturally concludes.

SEGMENTATION_PROMPT: Instructs the LLM to analyze {transcript} and segment it into atomic ideas, enforcing the SegmentationResult structure.

GENERATIVE_LINKING_PROMPT: Instructs the LLM to compare {new_note} against {neighbors}, enforcing the LinkingResult structure.

Phase 4: The Agent Layer
Task 4.1: Interviewer Agent (src/services/agents/interviewer.py)

Initialize the interview LLM (Gemini Flash) using ChatGoogleGenerativeAI.

Dependencies: VectorService.

Method async generate_response(transcript: List[dict], context_query: str):

RAG: Call VectorService.semantic_search(context_query).

LCEL Chain: Define the chain: ChatPromptTemplate | LLM | StrOutputParser().

Invocation: Use .ainvoke() with formatted context and chat history.

Termination Check: Check whether the response contains [END_SESSION].

Return the response text and a boolean indicating whether the session should end.

Task 4.2: Synthesizer Agent (src/services/agents/synthesizer.py)

Initialize the synthesis LLM (Gemini Pro).

Dependencies: VectorService.

Method async segment_transcript(transcript: List[dict]) -> SegmentationResult:

Define the structured chain: ChatPromptTemplate | LLM.with_structured_output(schema=SegmentationResult).

Invoke asynchronously and return the result.

Method async generate_links(new_note: Note, neighbors: List[Document]) -> LinkingResult:

Define the structured chain for linking using LinkingResult.

Format new_note and neighbors for the prompt.

Invoke asynchronously and return the result.

Phase 5: Orchestration and API
Task 5.1: Session Orchestration Service (src/services/session_service.py)
This service manages the lifecycle and coordinates agents and repositories.

Dependencies: Repositories, Agents, VectorService.

async start_session(topic: str): Create the session in the DB. Call InterviewerAgent for the first message. Update the transcript.

async handle_message(session_id: uuid.UUID, message: str): Get the session. Update the transcript with the user message. Call InterviewerAgent. Update the transcript with the AI response. If [END_SESSION] is detected, update status to processing.

async process_session_background_task(session_id: uuid.UUID): (the main synthesis pipeline)

Get the session transcript.

Call SynthesizerAgent.segment_transcript.

Persistence (SQLite): Save new notes from SegmentationResult using NoteRepository.

Indexing (ChromaDB): Add the newly created notes using VectorService.add_notes.

Linking loop:

For each new note, call VectorService.semantic_search to find neighbors.

Call SynthesizerAgent.generate_links.

Save links from LinkingResult using LinkRepository.

Finalize: Update session status to completed (or failed on error).

Task 5.2: API Endpoints (src/main.py)

Initialize the FastAPI app with a lifespan manager to call init_db().

Implement dependency injection for the DB session and SessionService.

POST /sessions.start: Takes StartSessionRequest, calls SessionService.start_session.

POST /sessions.sendMessage: Takes SendMessageRequest. Calls SessionService.handle_message. If the service indicates the session is ending, trigger the background task.

Background task trigger: Use FastAPI's BackgroundTasks to run SessionService.process_session_background_task(session_id). This is critical for returning an immediate API response.

GET /sessions.getStatus: Takes a session_id query parameter and returns the session's current status from the DB.
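The termination check described in Task 4.1 reduces to a small string operation. A hedged sketch (the helper name check_termination is hypothetical): detect the sentinel the interviewer prompt asks for, strip it from the user-facing reply, and report whether the session should end.

```python
# Sketch of the Task 4.1 termination check. The interviewer prompt
# instructs the model to emit the exact token [END_SESSION]; the agent
# strips it before returning the reply to the client.
END_TOKEN = "[END_SESSION]"

def check_termination(reply: str):
    should_end = END_TOKEN in reply
    cleaned = reply.replace(END_TOKEN, "").strip()
    return cleaned, should_end
```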
pytest.ini (new file, 9 lines added)
@@ -0,0 +1,9 @@
[pytest]
asyncio_mode = auto
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
markers =
    asyncio: mark test as an async test
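With `asyncio_mode = auto`, plain `async def test_*` functions are collected and awaited by pytest-asyncio without a `@pytest.mark.asyncio` decorator. A hypothetical test in that style (the awaited helper is a stand-in for a real repository call):

```python
# Under asyncio_mode = auto, pytest runs this coroutine itself; no
# decorator or event-loop boilerplate is needed in the test file.
import asyncio

async def fake_create_session() -> str:
    await asyncio.sleep(0)  # simulate awaited I/O
    return "active"

async def test_new_session_starts_active():
    assert await fake_create_session() == "active"
```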
requirements.txt (new file, 13 lines added)
@@ -0,0 +1,13 @@
fastapi
uvicorn[standard]
pydantic-settings
sqlmodel
aiosqlite
langchain
langchain-google-genai
langchain-community
chromadb
tiktoken
pytest
pytest-asyncio
httpx
run_tests.py (new file, 132 lines added)
@@ -0,0 +1,132 @@
#!/usr/bin/env python
"""
Test runner script for SkyTalk API.

This script can be run without pytest to validate basic functionality.
"""
import asyncio
import sys
from pathlib import Path

# Add the app directory to Python path
app_dir = Path(__file__).parent
sys.path.insert(0, str(app_dir))


async def test_basic_setup():
    """Test basic setup and imports."""
    print("🧪 Testing basic setup...")

    try:
        from app.core.config import settings
        print(f"✅ Config loaded: API key {'set' if settings.GOOGLE_API_KEY and settings.GOOGLE_API_KEY != 'your_google_api_key_here' else 'NOT set'}")

        from app.data.models import Session, Note, Link, StartSessionRequest
        print("✅ All models imported successfully")

        from app.data.database import init_db, get_session
        print("✅ Database modules imported successfully")

        from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
        print("✅ Repository modules imported successfully")

    except Exception as e:
        print(f"❌ Import error: {e}")
        return False

    return True


async def test_model_validation():
    """Test model creation and validation."""
    print("\n🧪 Testing model validation...")

    try:
        from app.data.models import StartSessionRequest, Session, Note
        import uuid

        # Test API models
        req = StartSessionRequest(topic="AI Ethics")
        print(f"✅ StartSessionRequest: {req.topic}")

        # Test database models
        session = Session()
        print(f"✅ Session created with ID: {session.id}")

        note = Note(
            title="Test Note",
            content="Test content",
            tags=["test"],
            session_id=session.id
        )
        print(f"✅ Note created with ID: {note.id}")

    except Exception as e:
        print(f"❌ Model validation error: {e}")
        return False

    return True


async def test_database_operations():
    """Test database initialization and basic operations."""
    print("\n🧪 Testing database operations...")

    try:
        from app.data.database import init_db, get_session
        from app.data.repositories import SessionRepository
        from app.data.models import Session

        # Initialize database
        await init_db()
        print("✅ Database tables created")
# Test repository operations
|
||||||
|
async with get_session() as db:
|
||||||
|
repo = SessionRepository(db)
|
||||||
|
session = await repo.create(status="test")
|
||||||
|
print(f"✅ Session created via repository: {session.id}")
|
||||||
|
|
||||||
|
retrieved = await repo.get(session.id)
|
||||||
|
print(f"✅ Session retrieved: {retrieved.status}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Database operation error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all tests."""
|
||||||
|
print("🚀 Running SkyTalk API validation tests...\n")
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
test_basic_setup,
|
||||||
|
test_model_validation,
|
||||||
|
test_database_operations
|
||||||
|
]
|
||||||
|
|
||||||
|
passed = 0
|
||||||
|
total = len(tests)
|
||||||
|
|
||||||
|
for test in tests:
|
||||||
|
if await test():
|
||||||
|
passed += 1
|
||||||
|
else:
|
||||||
|
break # Stop on first failure
|
||||||
|
|
||||||
|
print(f"\n📊 Test Results: {passed}/{total} passed")
|
||||||
|
|
||||||
|
if passed == total:
|
||||||
|
print("🎉 All basic validation tests passed!")
|
||||||
|
print("\nNext steps:")
|
||||||
|
print("1. Run full test suite: pytest")
|
||||||
|
print("2. Continue with Phase 3 implementation")
|
||||||
|
else:
|
||||||
|
print("❌ Some tests failed. Please fix issues before continuing.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
0
tests/__init__.py
Normal file
24
tests/test_config.py
Normal file
@@ -0,0 +1,24 @@
"""Test configuration and environment setup."""
from app.core.config import settings


def test_config_loads():
    """Test that configuration loads successfully."""
    assert settings.GOOGLE_API_KEY is not None
    assert settings.DATABASE_URL.startswith("sqlite+aiosqlite://")
    assert settings.CHROMA_PERSIST_DIR == "./chroma_db"


def test_model_names():
    """Test that model names are configured correctly."""
    assert "gemini" in settings.LLM_FLASH_MODEL.lower()
    assert "gemini" in settings.LLM_PRO_MODEL.lower()
    assert "embedding" in settings.EMBEDDING_MODEL.lower()


def test_api_settings():
    """Test API configuration."""
    assert isinstance(settings.API_PORT, int)
    assert settings.API_PORT > 0
    assert settings.API_HOST is not None
73
tests/test_database.py
Normal file
@@ -0,0 +1,73 @@
"""Test database initialization and basic operations."""
import pytest

from app.data.database import init_db, get_session
from app.data.models import Session


class TestDatabase:
    """Test database operations."""

    @pytest.mark.asyncio
    async def test_init_db(self):
        """Test database initialization creates tables."""
        await init_db()

        # Check that tables exist by attempting to create a session
        async with get_session() as db:
            session = Session(status="test")
            db.add(session)
            await db.commit()
            await db.refresh(session)
            assert session.id is not None

    @pytest.mark.asyncio
    async def test_session_context_manager(self):
        """Test database session context manager."""
        await init_db()

        # Test successful transaction
        async with get_session() as db:
            session = Session(status="test")
            db.add(session)
            # Context manager should auto-commit

        # Verify data was saved
        async with get_session() as db:
            from sqlmodel import select
            statement = select(Session).where(Session.status == "test")
            result = await db.exec(statement)
            found_session = result.first()
            assert found_session is not None
            assert found_session.status == "test"

    @pytest.mark.asyncio
    async def test_session_rollback_on_error(self):
        """Test that the session rolls back on errors."""
        await init_db()

        initial_count = 0
        async with get_session() as db:
            from sqlmodel import select, func
            result = await db.exec(select(func.count(Session.id)))
            initial_count = result.one()

        # Simulate an error during the transaction
        try:
            async with get_session() as db:
                session = Session(status="test_rollback")
                db.add(session)
                # Force an error
                raise ValueError("Simulated error")
        except ValueError:
            pass

        # Verify data was not saved due to rollback
        async with get_session() as db:
            from sqlmodel import select, func
            result = await db.exec(select(func.count(Session.id)))
            final_count = result.one()
            assert final_count == initial_count  # No new records
131
tests/test_models.py
Normal file
@@ -0,0 +1,131 @@
"""Test data models and validation."""
import uuid
from datetime import datetime

import pytest
from pydantic import ValidationError

from app.data.models import (
    StartSessionRequest,
    SendMessageRequest,
    SessionResponse,
    SessionStatusResponse,
    RawZettel,
    SegmentationResult,
    RawLink,
    LinkingResult,
    Note,
    Session,
    Link,
)


class TestAPIModels:
    """Test API request/response models."""

    def test_start_session_request(self):
        """Test StartSessionRequest validation."""
        req = StartSessionRequest(topic="AI Ethics")
        assert req.topic == "AI Ethics"

        # An empty topic should fail validation
        with pytest.raises(ValidationError):
            StartSessionRequest(topic="")

    def test_send_message_request(self):
        """Test SendMessageRequest validation."""
        session_id = uuid.uuid4()
        req = SendMessageRequest(session_id=session_id, message="Hello")
        assert req.session_id == session_id
        assert req.message == "Hello"

        # An invalid UUID should fail validation
        with pytest.raises(ValidationError):
            SendMessageRequest(session_id="invalid", message="Hello")

    def test_session_response(self):
        """Test SessionResponse model."""
        session_id = uuid.uuid4()
        resp = SessionResponse(
            session_id=session_id,
            status="active",
            message="Welcome!"
        )
        assert resp.session_id == session_id
        assert resp.status == "active"
        assert resp.message == "Welcome!"


class TestLLMModels:
    """Test LLM structured output models."""

    def test_raw_zettel(self):
        """Test RawZettel model."""
        zettel = RawZettel(
            title="Test Concept",
            content="This is a test concept about AI.",
            tags=["ai", "test", "concept"]
        )
        assert zettel.title == "Test Concept"
        assert len(zettel.tags) == 3
        assert "ai" in zettel.tags

    def test_segmentation_result(self):
        """Test SegmentationResult model."""
        zettels = [
            RawZettel(title="Concept 1", content="Content 1", tags=["tag1"]),
            RawZettel(title="Concept 2", content="Content 2", tags=["tag2"])
        ]
        result = SegmentationResult(notes=zettels)
        assert len(result.notes) == 2
        assert result.notes[0].title == "Concept 1"

    def test_raw_link(self):
        """Test RawLink model."""
        link = RawLink(
            target_note_id=str(uuid.uuid4()),
            relationship_context="Related through shared AI concepts"
        )
        assert uuid.UUID(link.target_note_id)  # Should not raise
        assert "AI" in link.relationship_context


class TestDatabaseModels:
    """Test database models."""

    def test_session_creation(self):
        """Test Session model creation."""
        session = Session()
        assert isinstance(session.id, uuid.UUID)
        assert session.status == "active"
        assert session.transcript == []
        assert isinstance(session.created_at, datetime)

    def test_note_creation(self):
        """Test Note model creation."""
        session_id = uuid.uuid4()
        note = Note(
            title="Test Note",
            content="This is test content",
            tags=["test", "note"],
            session_id=session_id
        )
        assert isinstance(note.id, uuid.UUID)
        assert note.title == "Test Note"
        assert note.content == "This is test content"
        assert note.tags == ["test", "note"]
        assert note.session_id == session_id
        assert isinstance(note.created_at, datetime)

    def test_link_creation(self):
        """Test Link model creation."""
        source_id = uuid.uuid4()
        target_id = uuid.uuid4()
        link = Link(
            context="These concepts are related through...",
            source_id=source_id,
            target_id=target_id
        )
        assert link.context == "These concepts are related through..."
        assert link.source_id == source_id
        assert link.target_id == target_id
200
tests/test_repositories.py
Normal file
@@ -0,0 +1,200 @@
"""Test repository pattern implementations."""
import pytest

from app.data.database import init_db, get_session
from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
from app.data.models import Note, Link


class TestSessionRepository:
    """Test SessionRepository operations."""

    @pytest.mark.asyncio
    async def test_create_session(self):
        """Test creating a new session."""
        await init_db()

        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="active")

            assert session.id is not None
            assert session.status == "active"
            assert session.transcript == []

    @pytest.mark.asyncio
    async def test_get_session(self):
        """Test retrieving a session by ID."""
        await init_db()

        # Create a session
        session_id = None
        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="test")
            session_id = session.id

        # Retrieve the session
        async with get_session() as db:
            repo = SessionRepository(db)
            retrieved = await repo.get(session_id)

            assert retrieved is not None
            assert retrieved.id == session_id
            assert retrieved.status == "test"

    @pytest.mark.asyncio
    async def test_update_status(self):
        """Test updating session status."""
        await init_db()

        # Create a session
        session_id = None
        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="active")
            session_id = session.id

        # Update the status
        async with get_session() as db:
            repo = SessionRepository(db)
            updated = await repo.update_status(session_id, "completed")

            assert updated is not None
            assert updated.status == "completed"

    @pytest.mark.asyncio
    async def test_append_transcript(self):
        """Test appending to the session transcript."""
        await init_db()

        # Create a session
        session_id = None
        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create()
            session_id = session.id

        # Append to the transcript
        message = {"role": "user", "content": "Hello"}
        async with get_session() as db:
            repo = SessionRepository(db)
            updated = await repo.append_transcript(session_id, message)

            assert updated is not None
            assert len(updated.transcript) == 1
            assert updated.transcript[0] == message


class TestNoteRepository:
    """Test NoteRepository operations."""

    @pytest.mark.asyncio
    async def test_create_bulk_notes(self):
        """Test creating multiple notes at once."""
        await init_db()

        # Create a session first
        session_id = None
        async with get_session() as db:
            session_repo = SessionRepository(db)
            session = await session_repo.create()
            session_id = session.id

        # Create notes
        notes = [
            Note(
                title="Note 1",
                content="Content 1",
                tags=["tag1"],
                session_id=session_id
            ),
            Note(
                title="Note 2",
                content="Content 2",
                tags=["tag2"],
                session_id=session_id
            )
        ]

        async with get_session() as db:
            repo = NoteRepository(db)
            created_notes = await repo.create_bulk(notes)

            assert len(created_notes) == 2
            assert all(note.id is not None for note in created_notes)
            assert created_notes[0].title == "Note 1"
            assert created_notes[1].title == "Note 2"

    @pytest.mark.asyncio
    async def test_get_by_session(self):
        """Test retrieving notes by session ID."""
        await init_db()

        # Create a session and notes
        session_id = None
        async with get_session() as db:
            session_repo = SessionRepository(db)
            session = await session_repo.create()
            session_id = session.id

            notes = [
                Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session_id),
                Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session_id)
            ]

            note_repo = NoteRepository(db)
            await note_repo.create_bulk(notes)

        # Retrieve the notes
        async with get_session() as db:
            repo = NoteRepository(db)
            session_notes = await repo.get_by_session(session_id)

            assert len(session_notes) == 2
            titles = [note.title for note in session_notes]
            assert "Note 1" in titles
            assert "Note 2" in titles


class TestLinkRepository:
    """Test LinkRepository operations."""

    @pytest.mark.asyncio
    async def test_create_bulk_links(self):
        """Test creating multiple links at once."""
        await init_db()

        # Create a session and notes first
        note_ids = []
        async with get_session() as db:
            session_repo = SessionRepository(db)
            session = await session_repo.create()

            notes = [
                Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session.id),
                Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session.id)
            ]

            note_repo = NoteRepository(db)
            created_notes = await note_repo.create_bulk(notes)
            note_ids = [note.id for note in created_notes]

        # Create a link between the two notes
        links = [
            Link(
                context="These notes are related because...",
                source_id=note_ids[0],
                target_id=note_ids[1]
            )
        ]

        async with get_session() as db:
            repo = LinkRepository(db)
            created_links = await repo.create_bulk(links)

            assert len(created_links) == 1
            assert created_links[0].id is not None
            assert created_links[0].source_id == note_ids[0]
            assert created_links[0].target_id == note_ids[1]