feat: implement data layer with comprehensive test infrastructure
- Define SQLModel schemas for Session, Note, and Link entities
- Add API request/response models for RPC endpoints
- Create LLM structured output models for Zettel extraction
- Set up async database initialization with SQLModel and aiosqlite
- Implement repository pattern for CRUD operations
- Add complete test suite with pytest configuration
- Create validation test runner for development workflow
- Add .gitignore for Python/FastAPI project security
12
.claude/settings.local.json
Normal file
@@ -0,0 +1,12 @@
{
  "permissions": {
    "allow": [
      "Bash(mkdir:*)",
      "Bash(python run_tests.py:*)",
      "Bash(source:*)",
      "Bash(pytest:*)"
    ],
    "deny": [],
    "ask": []
  }
}
104
.gitignore
vendored
Normal file
@@ -0,0 +1,104 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Virtual environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDE specific files
.vscode/
.idea/
*.swp
*.swo
*~

# OS specific files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Project specific files
*.db
*.sqlite
*.sqlite3
skytalk.db*

# ChromaDB vector database
chroma_db/
*.chroma

# Logs
*.log
logs/

# Environment variables (keep template)
.env.local
.env.production
.env.development

# FastAPI specific
.mypy_cache/
.dmypy.json
dmypy.json

# Jupyter Notebooks
.ipynb_checkpoints
*.ipynb

# Documentation builds
docs/_build/
site/
116
CLAUDE.md
Normal file
@@ -0,0 +1,116 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with
code in this repository.

## Project Overview

SkyTalk API is an AI-powered backend service for conversational idea exploration
and knowledge synthesis. It interviews users to explore ideas and synthesizes
conversations into a structured, semantically-linked knowledge base using the
Zettelkasten method.

## Core Development Commands

### Environment Setup

```bash
# Install dependencies using uv package manager
uv pip install -r requirements.txt

# Run the API server (async FastAPI)
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```

### Code Quality

```bash
# Format code with black (mandatory)
black .

# Type checking
mypy .

# Linting
ruff check .
```

## Architecture & Implementation Standards

### Three-Layer Architecture

1. **API Layer**: FastAPI RPC-style endpoints (`/sessions/start`,
   `/sessions/message`)
2. **Services Layer**: Orchestration logic with LangChain LCEL
3. **Data Layer**: SQLModel (SQLite) + ChromaDB (embedded, persisted)

### Critical Implementation Rules

**Database Models**:

- SQLModel is the single source of truth for all database schemas
- ChromaDB runs in embedded mode, persisted to disk
- All database operations must be async

**AI Integration**:

- **Interviewing**: Use `gemini-2.5-flash-latest` (optimized for speed)
- **Synthesis/Linking**: Use `gemini-2.5-pro-latest` (optimized for reasoning)
- **Embeddings**: Use `models/text-embedding-004`
- **Structured Output**: Always use `.with_structured_output()` with Pydantic
  models for data extraction - never parse raw text

**Code Standards**:

- Maximum 400 lines per Python file
- Full type hints required (Pydantic V2 for API, SQLModel for DB)
- All I/O operations must use async/await
- Configuration via environment variables with `pydantic-settings`
- Use `HTTPException` for client errors

### Project Structure

```
api/
├── app/
│   ├── api/              # FastAPI endpoints (RPC-style)
│   ├── services/         # Business logic, LangChain agents
│   │   ├── interviewer.py
│   │   ├── synthesizer.py
│   │   └── vector.py
│   ├── data/             # Repositories and database models
│   │   ├── models/       # SQLModel definitions
│   │   └── repositories/
│   ├── core/             # Configuration, prompts
│   └── main.py           # FastAPI app initialization
├── requirements.txt      # Dependencies managed via uv
└── .env                  # Environment variables
```

### Key Implementation Patterns

**LangChain LCEL Pipeline**:

```python
# with_structured_output() already returns an OutputModel instance,
# so no separate output parser is chained after it.
chain = prompt | llm.with_structured_output(OutputModel)
result = await chain.ainvoke({"input": data})
```

**Async Database Operations**:

```python
# Renamed from get_session so it no longer shadows the context manager it calls.
async def fetch_session(session_id: uuid.UUID) -> Session | None:
    async with get_session() as db:
        return await db.get(Session, session_id)
```

**Background Task for Synthesis**:

```python
background_tasks.add_task(synthesize_session, session_id)
```

### Overall Style

- NO sycophancy -- push back/suggest alternative routes when it would help
  improve the project
57
README.md
@@ -81,3 +81,60 @@ SkyTalkAPI: {
# Connections to External APIs
SkyTalkAPI.Services -> ExternalAPIs.Gemini
```

---

## MVP Features (Phase 1)

At the completion of this implementation plan, the SkyTalk API will support the
following core features:

- **Session Management:** Start a new interview session based on an initial
  topic.

- **RAG-Powered Interviewing:** Engage in a back-and-forth conversation where
  the AI's questions are informed by existing knowledge in the vector store.

- **Automatic Session Termination:** The AI can detect a natural conclusion to
  the conversation.

- **Asynchronous Synthesis:** Once the interview ends, a background process is
  triggered to analyze the transcript.

- **Semantic Segmentation:** The transcript is intelligently broken down into
  atomic "Zettels" (notes), each focusing on a single concept.

- **Vector Indexing:** Each new note is converted into a vector embedding and
  stored for future RAG.

- **Generative Linking:** The system identifies semantically related notes and
  uses an LLM to generate a rich, contextual link explaining the relationship
  between them.

- **Status Tracking:** Endpoints to check the status of a session (active,
processing, completed).
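
The session statuses listed above suggest a small state machine. The following is a minimal sketch, not the project's implementation: it assumes a fourth `failed` state (mentioned in the implementation plan's `Session.status` comment) and an assumed transition table. `SessionStatus`, `ALLOWED_TRANSITIONS`, and `can_transition` are hypothetical names used only for illustration.

```python
from enum import Enum


class SessionStatus(str, Enum):
    """Hypothetical enum mirroring the status strings stored on a session."""

    ACTIVE = "active"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


# Assumed lifecycle: the interview ends, synthesis runs, then it succeeds or fails.
ALLOWED_TRANSITIONS = {
    SessionStatus.ACTIVE: {SessionStatus.PROCESSING},
    SessionStatus.PROCESSING: {SessionStatus.COMPLETED, SessionStatus.FAILED},
    SessionStatus.COMPLETED: set(),
    SessionStatus.FAILED: set(),
}


def can_transition(current: str, new: str) -> bool:
    """Return True if moving from `current` to `new` follows the assumed lifecycle."""
    try:
        src, dst = SessionStatus(current), SessionStatus(new)
    except ValueError:
        return False  # unknown status string
    return dst in ALLOWED_TRANSITIONS[src]
```

A status endpoint could use a check like this to reject out-of-order updates, though whether the API enforces transitions at all is not stated here.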

---

## Next Steps (Post-MVP)

- **Authentication:** Implement user accounts and authentication (e.g., JWT) to
  create user-specific knowledge bases.

- **Frontend Integration:** Build a web-based frontend (e.g., at
  www.skytalk.app) that consumes this API.

- **Knowledge Graph Visualization:** Add endpoints to export the note-and-link
  structure in a format suitable for graph visualization libraries (e.g., D3.js,
  Vis.js).

- **Note Editing and Management:** Provide endpoints for users to manually edit,
  delete, or merge notes.

- **Advanced Search:** Implement more sophisticated search functionalities
  beyond simple semantic search, such as filtering by tags or searching within
  link contexts.

- **Scalable Infrastructure:** Migrate from SQLite/embedded ChromaDB to a
  production-grade database (e.g., PostgreSQL with pgvector) and a managed
  vector database for scalability.
19
app/core/config.py
Normal file
@@ -0,0 +1,19 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    GOOGLE_API_KEY: str
    DATABASE_URL: str = "sqlite+aiosqlite:///./skytalk.db"
    CHROMA_PERSIST_DIR: str = "./chroma_db"

    LLM_FLASH_MODEL: str = "gemini-2.5-flash-latest"
    LLM_PRO_MODEL: str = "gemini-2.5-pro-latest"
    EMBEDDING_MODEL: str = "models/text-embedding-004"

    API_HOST: str = "0.0.0.0"
    API_PORT: int = 8000


settings = Settings()
33
app/data/database.py
Normal file
@@ -0,0 +1,33 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.config import settings


engine = create_async_engine(
    settings.DATABASE_URL,
    echo=False,
    future=True
)


async def init_db() -> None:
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)


@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    async with AsyncSession(engine) as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
73
app/data/models.py
Normal file
@@ -0,0 +1,73 @@
import uuid
from datetime import datetime
from typing import List, Optional, Any
from sqlmodel import Field, SQLModel, Relationship, JSON, Column


class StartSessionRequest(SQLModel):
    topic: str = Field(min_length=1, description="Topic for the session")


class SendMessageRequest(SQLModel):
    session_id: uuid.UUID
    message: str = Field(min_length=1, description="User message content")


class SessionResponse(SQLModel):
    session_id: uuid.UUID
    status: str
    message: str


class SessionStatusResponse(SQLModel):
    session_id: uuid.UUID
    status: str
    notes_count: int = 0
    created_at: datetime


class RawZettel(SQLModel):
    title: str = Field(description="Concise title of the atomic concept.")
    content: str = Field(description="The synthesized mini-blog post.")
    tags: List[str] = Field(description="List of relevant conceptual keywords.")


class SegmentationResult(SQLModel):
    notes: List[RawZettel]


class RawLink(SQLModel):
    target_note_id: str = Field(description="UUID of the related note.")
    relationship_context: str = Field(description="Explanation of how the new note relates to the target note.")


class LinkingResult(SQLModel):
    links: List[RawLink]


class Link(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    context: str

    source_id: uuid.UUID = Field(foreign_key="note.id")
    target_id: uuid.UUID = Field(foreign_key="note.id")


class Note(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    title: str
    content: str
    tags: List[str] = Field(sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)

    session_id: uuid.UUID = Field(foreign_key="session.id")
    session: Optional["Session"] = Relationship(back_populates="notes")


class Session(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    status: str = Field(default="active")
    transcript: List[dict] = Field(default=[], sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)

    notes: List["Note"] = Relationship(back_populates="session")
98
app/data/repositories.py
Normal file
@@ -0,0 +1,98 @@
import uuid
from typing import List, Optional

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.data.models import Session, Note, Link


class SessionRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create(self, status: str = "active") -> Session:
        db_session = Session(status=status)
        self.session.add(db_session)
        await self.session.commit()
        await self.session.refresh(db_session)
        return db_session

    async def get(self, session_id: uuid.UUID) -> Optional[Session]:
        statement = select(Session).where(Session.id == session_id)
        result = await self.session.exec(statement)
        return result.first()

    async def update_status(self, session_id: uuid.UUID, status: str) -> Optional[Session]:
        db_session = await self.get(session_id)
        if db_session:
            db_session.status = status
            await self.session.commit()
            await self.session.refresh(db_session)
        return db_session

    async def append_transcript(self, session_id: uuid.UUID, message: dict) -> Optional[Session]:
        db_session = await self.get(session_id)
        if db_session:
            transcript = db_session.transcript.copy()
            transcript.append(message)
            db_session.transcript = transcript
            await self.session.commit()
            await self.session.refresh(db_session)
        return db_session

    async def get_with_notes(self, session_id: uuid.UUID) -> Optional[Session]:
        statement = select(Session).where(Session.id == session_id)
        result = await self.session.exec(statement)
        session = result.first()
        if session:
            notes_statement = select(Note).where(Note.session_id == session_id)
            notes_result = await self.session.exec(notes_statement)
            session.notes = notes_result.all()
        return session


class NoteRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create_bulk(self, notes: List[Note]) -> List[Note]:
        for note in notes:
            self.session.add(note)
        await self.session.commit()
        for note in notes:
            await self.session.refresh(note)
        return notes

    async def get(self, note_id: uuid.UUID) -> Optional[Note]:
        statement = select(Note).where(Note.id == note_id)
        result = await self.session.exec(statement)
        return result.first()

    async def get_by_session(self, session_id: uuid.UUID) -> List[Note]:
        statement = select(Note).where(Note.session_id == session_id)
        result = await self.session.exec(statement)
        return result.all()


class LinkRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create_bulk(self, links: List[Link]) -> List[Link]:
        for link in links:
            self.session.add(link)
        await self.session.commit()
        for link in links:
            await self.session.refresh(link)
        return links

    async def get_by_source(self, source_id: uuid.UUID) -> List[Link]:
        statement = select(Link).where(Link.source_id == source_id)
        result = await self.session.exec(statement)
        return result.all()

    async def get_by_target(self, target_id: uuid.UUID) -> List[Link]:
        statement = select(Link).where(Link.target_id == target_id)
        result = await self.session.exec(statement)
        return result.all()
241
implementation-plan.txt
Normal file
@@ -0,0 +1,241 @@
Product Requirements Document (PRD) & Implementation Plan
1. Overview
SkyTalk (api.skytalk.app) is an API service that functions as an AI interviewer to elicit user ideas, then processes the conversation to generate "Zettels" (atomic mini-blog posts). These Zettels are semantically linked using RAG and "Generative Linking."

2. Implementation Tasks (Linear Execution Plan)
Phase 1: Project Setup and Core Utilities
Task 1.1: Project Structure Initialization
Create the following directory and file structure:

Bash

skytalk-api/
├── .venv/
├── src/
│   ├── config.py
│   ├── db/
│   │   ├── database.py
│   │   └── repositories.py
│   ├── services/
│   │   ├── agents/
│   │   │   ├── interviewer.py
│   │   │   ├── synthesizer.py
│   │   │   └── prompts.py
│   │   ├── session_service.py
│   │   └── vector_service.py
│   ├── models.py
│   └── main.py
├── .env
└── requirements.txt
Task 1.2: Dependencies and Configuration

Create requirements.txt:

Plaintext

fastapi
uvicorn[standard]
pydantic-settings
sqlmodel
aiosqlite
langchain
langchain-google-genai
langchain-community
chromadb
tiktoken
Create src/config.py:

Python

from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    GOOGLE_API_KEY: str
    DATABASE_URL: str = "sqlite+aiosqlite:///./skytalk.db"
    CHROMA_PERSIST_DIR: str = "./chroma_db"

    LLM_FLASH_MODEL: str = "gemini-2.5-flash-latest"
    LLM_PRO_MODEL: str = "gemini-2.5-pro-latest"
    EMBEDDING_MODEL: str = "models/text-embedding-004"

settings = Settings()
Phase 2: Data Models and Persistence
Task 2.1: SQLModel Definitions (src/models.py)
Define all data structures, combining database models and API schemas.

Python

import uuid
from datetime import datetime
from typing import List, Optional, Any
from sqlmodel import Field, SQLModel, Relationship, JSON, Column

# --- API Payloads (Not stored in DB) ---
class StartSessionRequest(SQLModel):
    topic: str

class SendMessageRequest(SQLModel):
    session_id: uuid.UUID
    message: str

# --- LLM Structured Output Models ---
class RawZettel(SQLModel):
    title: str = Field(description="Concise title of the atomic concept.")
    content: str = Field(description="The synthesized mini-blog post.")
    tags: List[str] = Field(description="List of relevant conceptual keywords.")

class SegmentationResult(SQLModel):
    notes: List[RawZettel]

class RawLink(SQLModel):
    target_note_id: str = Field(description="UUID of the related note.")
    relationship_context: str = Field(description="Explanation of how the new note relates to the target note.")

class LinkingResult(SQLModel):
    links: List[RawLink]

# --- Database Models ---
class Link(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    context: str

    source_id: uuid.UUID = Field(foreign_key="note.id")
    target_id: uuid.UUID = Field(foreign_key="note.id")

class Note(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    title: str
    content: str
    tags: List[str] = Field(sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)

    session_id: uuid.UUID = Field(foreign_key="session.id")
    session: "Session" = Relationship(back_populates="notes")

class Session(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    status: str = Field(default="active")  # active, processing, completed, failed
    transcript: List[dict] = Field(default=[], sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)

    notes: List["Note"] = Relationship(back_populates="session")
Task 2.2: Database Initialization (src/db/database.py)

Create the engine with create_async_engine from sqlalchemy.ext.asyncio.

Use AsyncSession from sqlmodel.ext.asyncio.session and define init_db() to create tables on startup.

Implement a FastAPI dependency get_session() to inject the AsyncSession.

Task 2.3: Repositories (src/db/repositories.py)
Implement async CRUD operations using SQLModel's syntax.

SessionRepository: create, get, update_status, append_transcript.

NoteRepository: create_bulk (efficiently inserts multiple notes).

LinkRepository: create_bulk.

Phase 3: Core Services
Task 3.1: Vector Store Service (src/services/vector_service.py)

Initialize GoogleGenerativeAIEmbeddings from langchain_google_genai.

Initialize a persistent Chroma client.

async add_notes(notes: List[Note]): Adds note content and metadata to Chroma, using note ids as document IDs.

async semantic_search(query: str, k: int = 5): Performs similarity search and returns LangChain Document objects.
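
Conceptually, semantic_search ranks stored note embeddings by their similarity to the query embedding. The sketch below shows that idea with brute-force cosine similarity over an in-memory dict. It is not how Chroma implements search; cosine_similarity, top_k, and the toy vectors are hypothetical names and data used only for illustration.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query: List[float], index: Dict[str, List[float]], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k note ids most similar to `query`, best match first."""
    scored = [(note_id, cosine_similarity(query, vec)) for note_id, vec in index.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 2-dimensional "embeddings" keyed by note id (made-up data).
index = {"note-a": [1.0, 0.0], "note-b": [0.0, 1.0], "note-c": [1.0, 1.0]}
neighbors = top_k([1.0, 0.1], index, k=2)
```

Here neighbors lists note-a first and note-c second, since the query vector points almost entirely along the first axis.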

Task 3.2: Prompts (src/services/agents/prompts.py)

INTERVIEWER_SYSTEM_PROMPT: Socratic persona. Accepts {retrieved_context}. Instructs AI to ask one question at a time and to output the exact token [END_SESSION] when the conversation naturally concludes.

SEGMENTATION_PROMPT: Instructs LLM to analyze {transcript} and segment it into atomic ideas, enforcing the SegmentationResult structure.

GENERATIVE_LINKING_PROMPT: Instructs LLM to compare {new_note} against {neighbors}, enforcing the LinkingResult structure.
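
The three prompts above could be kept as plain string templates with the named placeholders. A rough sketch follows: the prompt wording is invented for illustration (the real prompt text is not given in this plan), and only the constant names, the placeholder names, and the [END_SESSION] token come from the task description.

```python
END_SESSION_TOKEN = "[END_SESSION]"

# Illustrative wording only; the actual prompt text is an assumption.
INTERVIEWER_SYSTEM_PROMPT = (
    "You are a Socratic interviewer. Ask exactly one question at a time.\n"
    "Relevant notes from earlier sessions:\n{retrieved_context}\n"
    "When the conversation has naturally concluded, output the exact token "
    + END_SESSION_TOKEN
)

SEGMENTATION_PROMPT = (
    "Segment the following transcript into atomic ideas, one concept per note.\n"
    "Transcript:\n{transcript}"
)

GENERATIVE_LINKING_PROMPT = (
    "Explain how the new note relates to each neighbor.\n"
    "New note:\n{new_note}\n"
    "Neighbors:\n{neighbors}"
)

# Filling a placeholder with str.format, as a prompt template would do.
system_message = INTERVIEWER_SYSTEM_PROMPT.format(
    retrieved_context="- Note: spaced repetition"
)
```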

Phase 4: The Agent Layer
Task 4.1: Interviewer Agent (src/services/agents/interviewer.py)

Initialize the Interview LLM (Gemini Flash) using ChatGoogleGenerativeAI.

Dependencies: VectorService.

Method async generate_response(transcript: List[dict], context_query: str):

RAG: Call VectorService.semantic_search(context_query).

LCEL Chain: Define the chain: ChatPromptTemplate | LLM | StrOutputParser().

Invocation: Use .ainvoke() with formatted context and chat history.

Termination Check: Check if the response contains [END_SESSION].

Return the response text and a boolean indicating if the session should end.
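
The termination check and return value described above can be sketched in a few lines. This is a minimal illustration, not the agent's actual code: check_termination is a hypothetical helper name, and stripping the token from the text shown to the user is an assumption (the plan only says to check for it).

```python
from typing import Tuple

END_SESSION_TOKEN = "[END_SESSION]"


def check_termination(response: str) -> Tuple[str, bool]:
    """Report whether the reply contains the end token and return the reply without it."""
    should_end = END_SESSION_TOKEN in response
    # Assumption: the control token should not be shown to the user.
    cleaned = response.replace(END_SESSION_TOKEN, "").strip()
    return cleaned, should_end
```

For example, check_termination("Thanks! [END_SESSION]") yields the text "Thanks!" together with True, while a reply without the token comes back unchanged with False.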

Task 4.2: Synthesizer Agent (src/services/agents/synthesizer.py)

Initialize the Synthesis LLM (Gemini Pro).

Dependencies: VectorService.

Method async segment_transcript(transcript: List[dict]) -> SegmentationResult:

Define the structured chain: ChatPromptTemplate | LLM.with_structured_output(schema=SegmentationResult).

Invoke asynchronously and return the result.

Method async generate_links(new_note: Note, neighbors: List[Document]) -> LinkingResult:

Define the structured chain for linking using LinkingResult.

Format new_note and neighbors for the prompt.

Invoke asynchronously and return the result.

Phase 5: Orchestration and API
Task 5.1: Session Orchestration Service (src/services/session_service.py)
This service manages the lifecycle and coordinates agents and repositories.

Dependencies: Repositories, Agents, VectorService.

async start_session(topic: str): Create session in DB. Call InterviewerAgent for the first message. Update transcript.

async handle_message(session_id: uuid.UUID, message: str): Get session. Update transcript with user message. Call InterviewerAgent. Update transcript with AI response. If [END_SESSION] detected, update status to processing.

async process_session_background_task(session_id: uuid.UUID): (The main synthesis pipeline)

Get session transcript.

Call SynthesizerAgent.segment_transcript.

Persistence (SQLite): Save new notes from SegmentationResult using NoteRepository.

Indexing (ChromaDB): Add the newly created notes using VectorService.add_notes.

Linking Loop:

For each new note, call VectorService.semantic_search to find neighbors.

Call SynthesizerAgent.generate_links.

Save links from LinkingResult using LinkRepository.

Finalize: Update session status to completed (or failed on error).
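
The pipeline steps above can be sketched with plain asyncio and stub collaborators. This is an outline under stated assumptions, not the service itself: process_session, the state dict, and the role/content transcript keys are hypothetical, and the stubs stand in for SynthesizerAgent, the repositories, and VectorService.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List

# Hypothetical stand-ins for the real collaborators; each is an async callable.
Segment = Callable[[List[dict]], Awaitable[List[str]]]
FindNeighbors = Callable[[str], Awaitable[List[str]]]


async def process_session(
    transcript: List[dict],
    segment: Segment,
    find_neighbors: FindNeighbors,
    state: Dict[str, object],
) -> None:
    """Run segment -> persist -> link, recording the final status in `state`."""
    state["status"] = "processing"
    try:
        notes = await segment(transcript)  # stands in for segment_transcript
        state["notes"] = notes  # stands in for NoteRepository + VectorService.add_notes
        links = []
        for note in notes:  # linking loop
            for neighbor in await find_neighbors(note):
                if neighbor != note:  # a note is not linked to itself
                    links.append((note, neighbor))
        state["links"] = links  # stands in for LinkRepository.create_bulk
        state["status"] = "completed"
    except Exception:
        state["status"] = "failed"  # finalize on error


async def _demo() -> Dict[str, object]:
    async def segment(transcript: List[dict]) -> List[str]:
        # Assumed message shape: {"role": ..., "content": ...}
        return [m["content"] for m in transcript if m["role"] == "user"]

    async def find_neighbors(note: str) -> List[str]:
        return ["idea one", "idea two"]

    state: Dict[str, object] = {}
    transcript = [
        {"role": "user", "content": "idea one"},
        {"role": "ai", "content": "Why?"},
    ]
    await process_session(transcript, segment, find_neighbors, state)
    return state


result = asyncio.run(_demo())
```

Wrapping the whole pipeline in one try/except keeps the status update in a single place, so a failure in any step leaves the session marked failed rather than stuck in processing.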

Task 5.2: API Endpoints (src/main.py)

Initialize FastAPI app with a lifespan manager to call init_db().

Implement dependency injection for the DB session and SessionService.

POST /sessions.start: Takes StartSessionRequest, calls SessionService.start_session.

POST /sessions.sendMessage: Takes SendMessageRequest. Calls SessionService.handle_message. If the service indicates the session is ending, trigger the background task.

Background Task Trigger: Use FastAPI's BackgroundTasks to run SessionService.process_session_background_task(session_id). This is critical for returning an immediate API response.

GET /sessions.getStatus: Takes a session_id query parameter and returns the session's current status from the DB.
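
The reason for the background task is that the handler can answer before synthesis finishes. The sketch below illustrates that ordering with asyncio.create_task instead of FastAPI's BackgroundTasks, so it runs with the standard library alone. It is an analogy, not the endpoint code; send_message, synthesize, and the module-level dicts are hypothetical.

```python
import asyncio
from typing import Dict, List, Set

statuses: Dict[str, str] = {}
events: List[str] = []
background: Set[asyncio.Task] = set()  # keeps references so tasks are not garbage collected


async def synthesize(session_id: str) -> None:
    """Stand-in for the slow synthesis pipeline."""
    await asyncio.sleep(0.01)
    statuses[session_id] = "completed"
    events.append("synthesis finished")


async def send_message(session_id: str, session_ended: bool) -> Dict[str, str]:
    """Return right away; schedule synthesis in the background if the interview ended."""
    if session_ended:
        statuses[session_id] = "processing"
        task = asyncio.create_task(synthesize(session_id))
        background.add(task)
        task.add_done_callback(background.discard)
    events.append("response returned")
    return {"session_id": session_id, "status": statuses.get(session_id, "active")}


async def _demo() -> Dict[str, str]:
    response = await send_message("s1", session_ended=True)
    await asyncio.gather(*background)  # only so the demo can observe completion
    return response


response = asyncio.run(_demo())
```

The caller sees status processing in the response, and a later status check would report completed once the background work is done.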
9
pytest.ini
Normal file
@@ -0,0 +1,9 @@
[pytest]
asyncio_mode = auto
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
markers =
    asyncio: mark test as an async test
13
requirements.txt
Normal file
@@ -0,0 +1,13 @@
fastapi
uvicorn[standard]
pydantic-settings
sqlmodel
aiosqlite
langchain
langchain-google-genai
langchain-community
chromadb
tiktoken
pytest
pytest-asyncio
httpx
132
run_tests.py
Normal file
@@ -0,0 +1,132 @@
#!/usr/bin/env python
"""
Test runner script for SkyTalk API.

This script can be run without pytest to validate basic functionality.
"""
import asyncio
import sys
from pathlib import Path

# Add the app directory to Python path
app_dir = Path(__file__).parent
sys.path.insert(0, str(app_dir))


async def test_basic_setup():
    """Test basic setup and imports."""
    print("🧪 Testing basic setup...")

    try:
        from app.core.config import settings
        print(f"✅ Config loaded: API key {'set' if settings.GOOGLE_API_KEY and settings.GOOGLE_API_KEY != 'your_google_api_key_here' else 'NOT set'}")

        from app.data.models import Session, Note, Link, StartSessionRequest
        print("✅ All models imported successfully")

        from app.data.database import init_db, get_session
        print("✅ Database modules imported successfully")

        from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
        print("✅ Repository modules imported successfully")

    except Exception as e:
        print(f"❌ Import error: {e}")
        return False

    return True


async def test_model_validation():
    """Test model creation and validation."""
    print("\n🧪 Testing model validation...")

    try:
        from app.data.models import StartSessionRequest, Session, Note
        import uuid

        # Test API models
        req = StartSessionRequest(topic="AI Ethics")
        print(f"✅ StartSessionRequest: {req.topic}")

        # Test database models
        session = Session()
        print(f"✅ Session created with ID: {session.id}")

        note = Note(
            title="Test Note",
            content="Test content",
            tags=["test"],
            session_id=session.id
        )
        print(f"✅ Note created with ID: {note.id}")

    except Exception as e:
        print(f"❌ Model validation error: {e}")
        return False

    return True


async def test_database_operations():
    """Test database initialization and basic operations."""
    print("\n🧪 Testing database operations...")

    try:
        from app.data.database import init_db, get_session
        from app.data.repositories import SessionRepository
        from app.data.models import Session

        # Initialize database
        await init_db()
        print("✅ Database tables created")

        # Test repository operations
        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="test")
            print(f"✅ Session created via repository: {session.id}")

            retrieved = await repo.get(session.id)
            print(f"✅ Session retrieved: {retrieved.status}")

    except Exception as e:
        print(f"❌ Database operation error: {e}")
        return False

    return True


async def main():
    """Run all tests."""
    print("🚀 Running SkyTalk API validation tests...\n")

    tests = [
        test_basic_setup,
        test_model_validation,
        test_database_operations
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        if await test():
            passed += 1
        else:
            break  # Stop on first failure

    print(f"\n📊 Test Results: {passed}/{total} passed")

    if passed == total:
        print("🎉 All basic validation tests passed!")
        print("\nNext steps:")
        print("1. Run full test suite: pytest")
        print("2. Continue with Phase 3 implementation")
    else:
        print("❌ Some tests failed. Please fix issues before continuing.")
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())
0
tests/__init__.py
Normal file
24
tests/test_config.py
Normal file
@@ -0,0 +1,24 @@
"""Test configuration and environment setup."""
import pytest
from app.core.config import settings


def test_config_loads():
    """Test that configuration loads successfully."""
    assert settings.GOOGLE_API_KEY is not None
    assert settings.DATABASE_URL.startswith("sqlite+aiosqlite://")
    assert settings.CHROMA_PERSIST_DIR == "./chroma_db"


def test_model_names():
    """Test that model names are configured correctly."""
    assert "gemini" in settings.LLM_FLASH_MODEL.lower()
    assert "gemini" in settings.LLM_PRO_MODEL.lower()
    assert "embedding" in settings.EMBEDDING_MODEL.lower()


def test_api_settings():
    """Test API configuration."""
    assert isinstance(settings.API_PORT, int)
    assert settings.API_PORT > 0
    assert settings.API_HOST is not None
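The `app.core.config` module that these tests import is not shown in this part of the diff. As a rough illustration of the contract the assertions imply (string settings with defaults, and a port coerced to `int`), here is a minimal standard-library sketch. The names `SettingsSketch` and `load_settings`, the default host and port, and the example database path are all assumptions made for illustration, not the project's actual configuration.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class SettingsSketch:
    """Hypothetical settings container; field names mirror the tests above."""
    DATABASE_URL: str
    CHROMA_PERSIST_DIR: str
    API_HOST: str
    API_PORT: int


def load_settings(env: Optional[Mapping[str, str]] = None) -> SettingsSketch:
    """Build settings from a mapping (os.environ by default), applying defaults."""
    env = os.environ if env is None else env
    return SettingsSketch(
        # Assumed default: any URL with the sqlite+aiosqlite scheme would do.
        DATABASE_URL=env.get("DATABASE_URL", "sqlite+aiosqlite:///./example.db"),
        CHROMA_PERSIST_DIR=env.get("CHROMA_PERSIST_DIR", "./chroma_db"),
        API_HOST=env.get("API_HOST", "127.0.0.1"),
        # Environment values are strings, so the port is converted explicitly.
        API_PORT=int(env.get("API_PORT", "8000")),
    )


# Passing a plain dict keeps the example independent of the real environment.
settings_sketch = load_settings({"API_PORT": "9000"})
```

Accepting the environment as a parameter makes the loader easy to exercise in tests without touching `os.environ`.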
73
tests/test_database.py
Normal file
@@ -0,0 +1,73 @@
"""Test database initialization and basic operations."""
import pytest
import os
import tempfile
from sqlmodel import SQLModel

from app.data.database import init_db, get_session, engine
from app.data.models import Session, Note, Link


class TestDatabase:
    """Test database operations."""

    @pytest.mark.asyncio
    async def test_init_db(self):
        """Test database initialization creates tables."""
        await init_db()

        # Check that tables exist by attempting to create a session
        async with get_session() as db:
            session = Session(status="test")
            db.add(session)
            await db.commit()
            await db.refresh(session)
            assert session.id is not None

    @pytest.mark.asyncio
    async def test_session_context_manager(self):
        """Test database session context manager."""
        await init_db()

        # Test successful transaction
        async with get_session() as db:
            session = Session(status="test")
            db.add(session)
            # Context manager should auto-commit

        # Verify data was saved
        async with get_session() as db:
            from sqlmodel import select
            statement = select(Session).where(Session.status == "test")
            result = await db.exec(statement)
            found_session = result.first()
            assert found_session is not None
            assert found_session.status == "test"

    @pytest.mark.asyncio
    async def test_session_rollback_on_error(self):
        """Test that session rolls back on errors."""
        await init_db()

        initial_count = 0
        async with get_session() as db:
            from sqlmodel import select, func
            result = await db.exec(select(func.count(Session.id)))
            initial_count = result.one()

        # Simulate an error during transaction
        try:
            async with get_session() as db:
                session = Session(status="test_rollback")
                db.add(session)
                # Force an error
                raise ValueError("Simulated error")
        except ValueError:
            pass

        # Verify data was not saved due to rollback
        async with get_session() as db:
            from sqlmodel import select, func
            result = await db.exec(select(func.count(Session.id)))
            final_count = result.one()
            assert final_count == initial_count  # No new records
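The tests above rely on `get_session()` committing when the block exits normally and rolling back when it raises. The real implementation in `app.data.database` is not shown in this part of the diff, so the following is only a sketch of that commit-or-rollback behavior, using the standard-library `sqlite3` module as a stand-in for the async SQLModel session. The helper names `get_session_sketch` and `count_rows`, and the simplified `session` table, are hypothetical.

```python
import asyncio
import sqlite3
from contextlib import asynccontextmanager
from typing import Tuple


@asynccontextmanager
async def get_session_sketch(conn: sqlite3.Connection):
    """Yield the connection; commit on normal exit, roll back on any exception."""
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise  # Re-raise so the caller still sees the original error.


def count_rows(conn: sqlite3.Connection) -> int:
    """Return the number of rows in the stand-in session table."""
    return conn.execute("SELECT COUNT(*) FROM session").fetchone()[0]


async def demo() -> Tuple[int, int]:
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE session (id INTEGER PRIMARY KEY, status TEXT)")
    conn.commit()

    # Successful block: the insert is committed when the block exits.
    async with get_session_sketch(conn) as db:
        db.execute("INSERT INTO session (status) VALUES ('test')")
    after_commit = count_rows(conn)

    # Failing block: the insert is rolled back, mirroring the rollback test.
    try:
        async with get_session_sketch(conn) as db:
            db.execute("INSERT INTO session (status) VALUES ('test_rollback')")
            raise ValueError("Simulated error")
    except ValueError:
        pass
    after_rollback = count_rows(conn)

    conn.close()
    return after_commit, after_rollback


counts = asyncio.run(demo())
```

Using a fresh in-memory database per run also sidesteps the shared-state issue that a count-based rollback check can hit when tests reuse one database file.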
131
tests/test_models.py
Normal file
@@ -0,0 +1,131 @@
"""Test data models and validation."""
import uuid
from datetime import datetime

import pytest
from pydantic import ValidationError

from app.data.models import (
    StartSessionRequest,
    SendMessageRequest,
    SessionResponse,
    SessionStatusResponse,
    RawZettel,
    SegmentationResult,
    RawLink,
    LinkingResult,
    Note,
    Session,
    Link
)


class TestAPIModels:
    """Test API request/response models."""

    def test_start_session_request(self):
        """Test StartSessionRequest validation."""
        req = StartSessionRequest(topic="AI Ethics")
        assert req.topic == "AI Ethics"

        # Test empty topic fails
        with pytest.raises(ValidationError):
            StartSessionRequest(topic="")

    def test_send_message_request(self):
        """Test SendMessageRequest validation."""
        session_id = uuid.uuid4()
        req = SendMessageRequest(session_id=session_id, message="Hello")
        assert req.session_id == session_id
        assert req.message == "Hello"

        # Test invalid UUID fails
        with pytest.raises(ValidationError):
            SendMessageRequest(session_id="invalid", message="Hello")

    def test_session_response(self):
        """Test SessionResponse model."""
        session_id = uuid.uuid4()
        resp = SessionResponse(
            session_id=session_id,
            status="active",
            message="Welcome!"
        )
        assert resp.session_id == session_id
        assert resp.status == "active"
        assert resp.message == "Welcome!"


class TestLLMModels:
    """Test LLM structured output models."""

    def test_raw_zettel(self):
        """Test RawZettel model."""
        zettel = RawZettel(
            title="Test Concept",
            content="This is a test concept about AI.",
            tags=["ai", "test", "concept"]
        )
        assert zettel.title == "Test Concept"
        assert len(zettel.tags) == 3
        assert "ai" in zettel.tags

    def test_segmentation_result(self):
        """Test SegmentationResult model."""
        zettels = [
            RawZettel(title="Concept 1", content="Content 1", tags=["tag1"]),
            RawZettel(title="Concept 2", content="Content 2", tags=["tag2"])
        ]
        result = SegmentationResult(notes=zettels)
        assert len(result.notes) == 2
        assert result.notes[0].title == "Concept 1"

    def test_raw_link(self):
        """Test RawLink model."""
        link = RawLink(
            target_note_id=str(uuid.uuid4()),
            relationship_context="Related through shared AI concepts"
        )
        assert uuid.UUID(link.target_note_id)  # Should not raise
        assert "AI" in link.relationship_context


class TestDatabaseModels:
    """Test database models."""

    def test_session_creation(self):
        """Test Session model creation."""
        session = Session()
        assert isinstance(session.id, uuid.UUID)
        assert session.status == "active"
        assert session.transcript == []
        assert isinstance(session.created_at, datetime)

    def test_note_creation(self):
        """Test Note model creation."""
        session_id = uuid.uuid4()
        note = Note(
            title="Test Note",
            content="This is test content",
            tags=["test", "note"],
            session_id=session_id
        )
        assert isinstance(note.id, uuid.UUID)
        assert note.title == "Test Note"
        assert note.content == "This is test content"
        assert note.tags == ["test", "note"]
        assert note.session_id == session_id
        assert isinstance(note.created_at, datetime)

    def test_link_creation(self):
        """Test Link model creation."""
        source_id = uuid.uuid4()
        target_id = uuid.uuid4()
        link = Link(
            context="These concepts are related through...",
            source_id=source_id,
            target_id=target_id
        )
        assert link.context == "These concepts are related through..."
        assert link.source_id == source_id
        assert link.target_id == target_id
200
tests/test_repositories.py
Normal file
@@ -0,0 +1,200 @@
"""Test repository pattern implementations."""
import uuid
import pytest

from app.data.database import init_db, get_session
from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
from app.data.models import Session, Note, Link


class TestSessionRepository:
    """Test SessionRepository operations."""

    @pytest.mark.asyncio
    async def test_create_session(self):
        """Test creating a new session."""
        await init_db()

        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="active")

            assert session.id is not None
            assert session.status == "active"
            assert session.transcript == []

    @pytest.mark.asyncio
    async def test_get_session(self):
        """Test retrieving a session by ID."""
        await init_db()

        # Create a session
        session_id = None
        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="test")
            session_id = session.id

        # Retrieve the session
        async with get_session() as db:
            repo = SessionRepository(db)
            retrieved = await repo.get(session_id)

            assert retrieved is not None
            assert retrieved.id == session_id
            assert retrieved.status == "test"

    @pytest.mark.asyncio
    async def test_update_status(self):
        """Test updating session status."""
        await init_db()

        # Create a session
        session_id = None
        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="active")
            session_id = session.id

        # Update status
        async with get_session() as db:
            repo = SessionRepository(db)
            updated = await repo.update_status(session_id, "completed")

            assert updated is not None
            assert updated.status == "completed"

    @pytest.mark.asyncio
    async def test_append_transcript(self):
        """Test appending to session transcript."""
        await init_db()

        # Create a session
        session_id = None
        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create()
            session_id = session.id

        # Append to transcript
        message = {"role": "user", "content": "Hello"}
        async with get_session() as db:
            repo = SessionRepository(db)
            updated = await repo.append_transcript(session_id, message)

            assert updated is not None
            assert len(updated.transcript) == 1
            assert updated.transcript[0] == message


class TestNoteRepository:
    """Test NoteRepository operations."""

    @pytest.mark.asyncio
    async def test_create_bulk_notes(self):
        """Test creating multiple notes at once."""
        await init_db()

        # Create a session first
        session_id = None
        async with get_session() as db:
            session_repo = SessionRepository(db)
            session = await session_repo.create()
            session_id = session.id

        # Create notes
        notes = [
            Note(
                title="Note 1",
                content="Content 1",
                tags=["tag1"],
                session_id=session_id
            ),
            Note(
                title="Note 2",
                content="Content 2",
                tags=["tag2"],
                session_id=session_id
            )
        ]

        async with get_session() as db:
            repo = NoteRepository(db)
            created_notes = await repo.create_bulk(notes)

            assert len(created_notes) == 2
            assert all(note.id is not None for note in created_notes)
            assert created_notes[0].title == "Note 1"
            assert created_notes[1].title == "Note 2"

    @pytest.mark.asyncio
    async def test_get_by_session(self):
        """Test retrieving notes by session ID."""
        await init_db()

        # Create session and notes
        session_id = None
        async with get_session() as db:
            session_repo = SessionRepository(db)
            session = await session_repo.create()
            session_id = session.id

            notes = [
                Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session_id),
                Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session_id)
            ]

            note_repo = NoteRepository(db)
            await note_repo.create_bulk(notes)

        # Retrieve notes
        async with get_session() as db:
            repo = NoteRepository(db)
            session_notes = await repo.get_by_session(session_id)

            assert len(session_notes) == 2
            titles = [note.title for note in session_notes]
            assert "Note 1" in titles
            assert "Note 2" in titles


class TestLinkRepository:
    """Test LinkRepository operations."""

    @pytest.mark.asyncio
    async def test_create_bulk_links(self):
        """Test creating multiple links at once."""
        await init_db()

        # Create session and notes first
        note_ids = []
        async with get_session() as db:
            session_repo = SessionRepository(db)
            session = await session_repo.create()

            notes = [
                Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session.id),
                Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session.id)
            ]

            note_repo = NoteRepository(db)
            created_notes = await note_repo.create_bulk(notes)
            note_ids = [note.id for note in created_notes]

        # Create links
        links = [
            Link(
                context="These notes are related because...",
                source_id=note_ids[0],
                target_id=note_ids[1]
            )
        ]

        async with get_session() as db:
            repo = LinkRepository(db)
            created_links = await repo.create_bulk(links)

            assert len(created_links) == 1
            assert created_links[0].id is not None
            assert created_links[0].source_id == note_ids[0]
            assert created_links[0].target_id == note_ids[1]
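The repository classes exercised above live in `app.data.repositories`, which is not shown in this part of the diff. To make the interface the tests depend on concrete (`create`, `get`, `update_status`, `append_transcript`), here is a small in-memory sketch. `SessionRecord` and `InMemorySessionRepository` are hypothetical stand-ins backed by a dict, not the project's SQLModel-based classes, and the method signatures are inferred from the test calls only.

```python
import asyncio
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class SessionRecord:
    """Hypothetical stand-in for the Session model (not the real SQLModel class)."""
    status: str = "active"
    transcript: List[dict] = field(default_factory=list)
    id: uuid.UUID = field(default_factory=uuid.uuid4)


class InMemorySessionRepository:
    """Hypothetical repository backed by a dict instead of a database session."""

    def __init__(self, store: Dict[uuid.UUID, SessionRecord]) -> None:
        self._store = store

    async def create(self, status: str = "active") -> SessionRecord:
        record = SessionRecord(status=status)
        self._store[record.id] = record
        return record

    async def get(self, session_id: uuid.UUID) -> Optional[SessionRecord]:
        # Returns None for unknown IDs, matching the "is not None" checks above.
        return self._store.get(session_id)

    async def update_status(
        self, session_id: uuid.UUID, status: str
    ) -> Optional[SessionRecord]:
        record = self._store.get(session_id)
        if record is not None:
            record.status = status
        return record

    async def append_transcript(
        self, session_id: uuid.UUID, message: dict
    ) -> Optional[SessionRecord]:
        record = self._store.get(session_id)
        if record is not None:
            record.transcript.append(message)
        return record


async def demo() -> Optional[SessionRecord]:
    repo = InMemorySessionRepository({})
    created = await repo.create()
    await repo.update_status(created.id, "completed")
    await repo.append_transcript(created.id, {"role": "user", "content": "Hello"})
    return await repo.get(created.id)


final_record = asyncio.run(demo())
```

A dict-backed double like this can also serve as a fast substitute in unit tests for code that only needs the repository interface, leaving the database-backed tests to cover persistence.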