feat: implement data layer with comprehensive test infrastructure

- Define SQLModel schemas for Session, Note, and Link entities
- Add API request/response models for RPC endpoints
- Create LLM structured output models for Zettel extraction
- Set up async database initialization with SQLModel and aiosqlite
- Implement repository pattern for CRUD operations
- Add complete test suite with pytest configuration
- Create validation test runner for development workflow
- Add .gitignore for Python/FastAPI project security
Author: Albert
Date: 2025-08-17 01:25:16 +00:00
Parent: 980124be58
Commit: f60d61a78f
17 changed files with 1335 additions and 0 deletions


@@ -0,0 +1,12 @@
{
  "permissions": {
    "allow": [
      "Bash(mkdir:*)",
      "Bash(python run_tests.py:*)",
      "Bash(source:*)",
      "Bash(pytest:*)"
    ],
    "deny": [],
    "ask": []
  }
}

.gitignore (vendored, new file, 104 lines)

@@ -0,0 +1,104 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Virtual environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# IDE specific files
.vscode/
.idea/
*.swp
*.swo
*~
# OS specific files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Project specific files
*.db
*.sqlite
*.sqlite3
skytalk.db*
# ChromaDB vector database
chroma_db/
*.chroma
# Logs
*.log
logs/
# Environment variables (keep template)
.env.local
.env.production
.env.development
# FastAPI specific
.mypy_cache/
.dmypy.json
dmypy.json
# Jupyter Notebooks
.ipynb_checkpoints
*.ipynb
# Documentation builds
docs/_build/
site/

CLAUDE.md (new file, 116 lines)

@@ -0,0 +1,116 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with
code in this repository.
## Project Overview
SkyTalk API is an AI-powered backend service for conversational idea exploration
and knowledge synthesis. It interviews users to explore ideas and synthesizes
conversations into a structured, semantically-linked knowledge base using the
Zettelkasten method.
## Core Development Commands
### Environment Setup
```bash
# Install dependencies using uv package manager
uv pip install -r requirements.txt
# Run the API server (async FastAPI)
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
### Code Quality
```bash
# Format code with black (mandatory)
black .
# Type checking
mypy .
# Linting
ruff check .
```
## Architecture & Implementation Standards
### Three-Layer Architecture
1. **API Layer**: FastAPI RPC-style endpoints (`/sessions/start`,
`/sessions/message`)
2. **Services Layer**: Orchestration logic with LangChain LCEL
3. **Data Layer**: SQLModel (SQLite) + ChromaDB (embedded, persisted)
### Critical Implementation Rules
**Database Models**:
- SQLModel is the single source of truth for all database schemas
- ChromaDB runs in embedded mode, persisted to disk
- All database operations must be async
**AI Integration**:
- **Interviewing**: Use `gemini-2.5-flash-latest` (optimized for speed)
- **Synthesis/Linking**: Use `gemini-2.5-pro-latest` (optimized for reasoning)
- **Embeddings**: Use `models/text-embedding-004`
- **Structured Output**: Always use `.with_structured_output()` with Pydantic
models for data extraction - never parse raw text
**Code Standards**:
- Maximum 400 lines per Python file
- Full type hints required (Pydantic V2 for API, SQLModel for DB)
- All I/O operations must use async/await
- Configuration via environment variables with `pydantic-settings`
- Use `HTTPException` for client errors
### Project Structure
```
api/
├── app/
│   ├── api/              # FastAPI endpoints (RPC-style)
│   ├── services/         # Business logic, LangChain agents
│   │   ├── interviewer.py
│   │   ├── synthesizer.py
│   │   └── vector.py
│   ├── data/             # Repositories and database models
│   │   ├── models/       # SQLModel definitions
│   │   └── repositories/
│   ├── core/             # Configuration, prompts
│   └── main.py           # FastAPI app initialization
├── requirements.txt      # Dependencies managed via uv
└── .env                  # Environment variables
```
### Key Implementation Patterns
**LangChain LCEL Pipeline**:
```python
# .with_structured_output() already returns a parsed Pydantic object; no extra parser step
chain = prompt | llm.with_structured_output(OutputModel)
result = await chain.ainvoke({"input": data})
```
**Async Database Operations**:
```python
# Named fetch_session to avoid shadowing the get_session() session factory
async def fetch_session(session_id: str) -> Session:
    async with get_session() as db:
        return await db.get(Session, session_id)
```
**Background Task for Synthesis**:
```python
background_tasks.add_task(synthesize_session, session_id)
```
### Overall Style
- NO sycophancy: push back and suggest alternative routes when that would improve the project


@@ -81,3 +81,60 @@ SkyTalkAPI: {
# Connections to External APIs
SkyTalkAPI.Services -> ExternalAPIs.Gemini
```
---
## MVP Features (Phase 1)
At the completion of this implementation plan, the SkyTalk API will support the
following core features:
- **Session Management:** Start a new interview session based on an initial
topic.
- **RAG-Powered Interviewing:** Engage in a back-and-forth conversation where
the AI's questions are informed by existing knowledge in the vector store.
- **Automatic Session Termination:** The AI can detect a natural conclusion to
the conversation.
- **Asynchronous Synthesis:** Once the interview ends, a background process is
triggered to analyze the transcript.
- **Semantic Segmentation:** The transcript is intelligently broken down into
atomic "Zettels" (notes), each focusing on a single concept.
- **Vector Indexing:** Each new note is converted into a vector embedding and
stored for future RAG.
- **Generative Linking:** The system identifies semantically related notes and
uses an LLM to generate a rich, contextual link explaining the relationship
between them.
- **Status Tracking:** Endpoints to check the status of a session (active,
processing, completed).
---
## Next Steps (Post-MVP)
- **Authentication:** Implement user accounts and authentication (e.g., JWT) to
create user-specific knowledge bases.
- **Frontend Integration:** Build a web-based frontend (e.g., at
www.skytalk.app) that consumes this API.
- **Knowledge Graph Visualization:** Add endpoints to export the note-and-link
structure in a format suitable for graph visualization libraries (e.g., D3.js,
Vis.js).
- **Note Editing and Management:** Provide endpoints for users to manually edit,
delete, or merge notes.
- **Advanced Search:** Implement more sophisticated search functionalities
beyond simple semantic search, such as filtering by tags or searching within
link contexts.
- **Scalable Infrastructure:** Migrate from SQLite/embedded ChromaDB to a
production-grade database (e.g., PostgreSQL with pgvector) and a managed
vector database for scalability.

app/core/config.py (new file, 19 lines)

@@ -0,0 +1,19 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    GOOGLE_API_KEY: str
    DATABASE_URL: str = "sqlite+aiosqlite:///./skytalk.db"
    CHROMA_PERSIST_DIR: str = "./chroma_db"
    LLM_FLASH_MODEL: str = "gemini-2.5-flash-latest"
    LLM_PRO_MODEL: str = "gemini-2.5-pro-latest"
    EMBEDDING_MODEL: str = "models/text-embedding-004"
    API_HOST: str = "0.0.0.0"
    API_PORT: int = 8000


settings = Settings()

app/data/database.py (new file, 33 lines)

@@ -0,0 +1,33 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.config import settings

engine = create_async_engine(
    settings.DATABASE_URL,
    echo=False,
    future=True
)


async def init_db() -> None:
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)


@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    async with AsyncSession(engine) as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

app/data/models.py (new file, 73 lines)

@@ -0,0 +1,73 @@
import uuid
from datetime import datetime
from typing import List, Optional, Any

from sqlmodel import Field, SQLModel, Relationship, JSON, Column


class StartSessionRequest(SQLModel):
    topic: str = Field(min_length=1, description="Topic for the session")


class SendMessageRequest(SQLModel):
    session_id: uuid.UUID
    message: str = Field(min_length=1, description="User message content")


class SessionResponse(SQLModel):
    session_id: uuid.UUID
    status: str
    message: str


class SessionStatusResponse(SQLModel):
    session_id: uuid.UUID
    status: str
    notes_count: int = 0
    created_at: datetime


class RawZettel(SQLModel):
    title: str = Field(description="Concise title of the atomic concept.")
    content: str = Field(description="The synthesized mini-blog post.")
    tags: List[str] = Field(description="List of relevant conceptual keywords.")


class SegmentationResult(SQLModel):
    notes: List[RawZettel]


class RawLink(SQLModel):
    target_note_id: str = Field(description="UUID of the related note.")
    relationship_context: str = Field(description="Explanation of how the new note relates to the target note.")


class LinkingResult(SQLModel):
    links: List[RawLink]


class Link(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    context: str
    source_id: uuid.UUID = Field(foreign_key="note.id")
    target_id: uuid.UUID = Field(foreign_key="note.id")


class Note(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    title: str
    content: str
    tags: List[str] = Field(sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)
    session_id: uuid.UUID = Field(foreign_key="session.id")
    session: Optional["Session"] = Relationship(back_populates="notes")


class Session(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    status: str = Field(default="active")
    transcript: List[dict] = Field(default=[], sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)
    notes: List["Note"] = Relationship(back_populates="session")

app/data/repositories.py (new file, 98 lines)

@@ -0,0 +1,98 @@
import uuid
from typing import List, Optional

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.data.models import Session, Note, Link


class SessionRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create(self, status: str = "active") -> Session:
        db_session = Session(status=status)
        self.session.add(db_session)
        await self.session.commit()
        await self.session.refresh(db_session)
        return db_session

    async def get(self, session_id: uuid.UUID) -> Optional[Session]:
        statement = select(Session).where(Session.id == session_id)
        result = await self.session.exec(statement)
        return result.first()

    async def update_status(self, session_id: uuid.UUID, status: str) -> Optional[Session]:
        db_session = await self.get(session_id)
        if db_session:
            db_session.status = status
            await self.session.commit()
            await self.session.refresh(db_session)
        return db_session

    async def append_transcript(self, session_id: uuid.UUID, message: dict) -> Optional[Session]:
        db_session = await self.get(session_id)
        if db_session:
            transcript = db_session.transcript.copy()
            transcript.append(message)
            db_session.transcript = transcript
            await self.session.commit()
            await self.session.refresh(db_session)
        return db_session

    async def get_with_notes(self, session_id: uuid.UUID) -> Optional[Session]:
        statement = select(Session).where(Session.id == session_id)
        result = await self.session.exec(statement)
        session = result.first()
        if session:
            notes_statement = select(Note).where(Note.session_id == session_id)
            notes_result = await self.session.exec(notes_statement)
            session.notes = notes_result.all()
        return session


class NoteRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create_bulk(self, notes: List[Note]) -> List[Note]:
        for note in notes:
            self.session.add(note)
        await self.session.commit()
        for note in notes:
            await self.session.refresh(note)
        return notes

    async def get(self, note_id: uuid.UUID) -> Optional[Note]:
        statement = select(Note).where(Note.id == note_id)
        result = await self.session.exec(statement)
        return result.first()

    async def get_by_session(self, session_id: uuid.UUID) -> List[Note]:
        statement = select(Note).where(Note.session_id == session_id)
        result = await self.session.exec(statement)
        return result.all()


class LinkRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def create_bulk(self, links: List[Link]) -> List[Link]:
        for link in links:
            self.session.add(link)
        await self.session.commit()
        for link in links:
            await self.session.refresh(link)
        return links

    async def get_by_source(self, source_id: uuid.UUID) -> List[Link]:
        statement = select(Link).where(Link.source_id == source_id)
        result = await self.session.exec(statement)
        return result.all()

    async def get_by_target(self, target_id: uuid.UUID) -> List[Link]:
        statement = select(Link).where(Link.target_id == target_id)
        result = await self.session.exec(statement)
        return result.all()

implementation-plan.txt (new file, 241 lines)

@@ -0,0 +1,241 @@
Product Requirements Document (PRD) & Implementation Plan
1. Overview
SkyTalk (api.skytalk.app) is an API service that functions as an AI interviewer to elicit user ideas, then processes the conversation to generate "Zettels" (atomic mini-blog posts). These Zettels are semantically linked using RAG and "Generative Linking."
2. Implementation Tasks (Linear Execution Plan)
Phase 1: Project Setup and Core Utilities
Task 1.1: Project Structure Initialization
Create the following directory and file structure:
Bash
skytalk-api/
├── .venv/
├── src/
│   ├── config.py
│   ├── db/
│   │   ├── database.py
│   │   └── repositories.py
│   ├── services/
│   │   ├── agents/
│   │   │   ├── interviewer.py
│   │   │   ├── synthesizer.py
│   │   │   └── prompts.py
│   │   ├── session_service.py
│   │   └── vector_service.py
│   ├── models.py
│   └── main.py
├── .env
└── requirements.txt
Task 1.2: Dependencies and Configuration
Create requirements.txt:
Plaintext
fastapi
uvicorn[standard]
pydantic-settings
sqlmodel
aiosqlite
langchain
langchain-google-genai
langchain-community
chromadb
tiktoken
Create src/config.py:
Python
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    GOOGLE_API_KEY: str
    DATABASE_URL: str = "sqlite+aiosqlite:///./skytalk.db"
    CHROMA_PERSIST_DIR: str = "./chroma_db"
    LLM_FLASH_MODEL: str = "gemini-2.5-flash-latest"
    LLM_PRO_MODEL: str = "gemini-2.5-pro-latest"
    EMBEDDING_MODEL: str = "models/text-embedding-004"

settings = Settings()
Phase 2: Data Models and Persistence
Task 2.1: SQLModel Definitions (src/models.py)
Define all data structures, combining database models and API schemas.
Python
import uuid
from datetime import datetime
from typing import List, Optional, Any

from sqlmodel import Field, SQLModel, Relationship, JSON, Column

# --- API Payloads (Not stored in DB) ---
class StartSessionRequest(SQLModel):
    topic: str

class SendMessageRequest(SQLModel):
    session_id: uuid.UUID
    message: str

# --- LLM Structured Output Models ---
class RawZettel(SQLModel):
    title: str = Field(description="Concise title of the atomic concept.")
    content: str = Field(description="The synthesized mini-blog post.")
    tags: List[str] = Field(description="List of relevant conceptual keywords.")

class SegmentationResult(SQLModel):
    notes: List[RawZettel]

class RawLink(SQLModel):
    target_note_id: str = Field(description="UUID of the related note.")
    relationship_context: str = Field(description="Explanation of how the new note relates to the target note.")

class LinkingResult(SQLModel):
    links: List[RawLink]

# --- Database Models ---
class Link(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    context: str
    source_id: uuid.UUID = Field(foreign_key="note.id")
    target_id: uuid.UUID = Field(foreign_key="note.id")

class Note(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    title: str
    content: str
    tags: List[str] = Field(sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)
    session_id: uuid.UUID = Field(foreign_key="session.id")
    session: "Session" = Relationship(back_populates="notes")

class Session(SQLModel, table=True):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    status: str = Field(default="active")  # active, processing, completed, failed
    transcript: List[dict] = Field(default=[], sa_column=Column(JSON))
    created_at: datetime = Field(default_factory=datetime.utcnow)
    notes: List["Note"] = Relationship(back_populates="session")
Task 2.2: Database Initialization (src/db/database.py)
Create the async engine with create_async_engine (imported from sqlalchemy.ext.asyncio).
Define AsyncSession and init_db() to create tables on startup.
Implement a FastAPI dependency get_session() to inject the AsyncSession.
Task 2.3: Repositories (src/db/repositories.py)
Implement async CRUD operations using SQLModel's syntax.
SessionRepository: create, get, update_status, append_transcript.
NoteRepository: create_bulk (efficiently inserts multiple notes).
LinkRepository: create_bulk.
Phase 3: Core Services
Task 3.1: Vector Store Service (src/services/vector_service.py)
Initialize GoogleGenerativeAIEmbeddings from langchain_google_genai.
Initialize a persistent Chroma client.
async add_notes(notes: List[Note]): Adds note content and metadata to Chroma, using note ids as document IDs.
async semantic_search(query: str, k: int = 5): Performs similarity search and returns LangChain Document objects.
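As a reference, a minimal sketch of this service under the plan's assumptions (langchain-community's Chroma wrapper plus the settings and models modules from Phases 1-2); the collection name and metadata keys are illustrative:
Python
from typing import List

from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_google_genai import GoogleGenerativeAIEmbeddings

from src.config import settings
from src.models import Note


class VectorService:
    def __init__(self) -> None:
        embeddings = GoogleGenerativeAIEmbeddings(model=settings.EMBEDDING_MODEL)
        # Embedded Chroma client, persisted to disk per the architecture rules
        self.store = Chroma(
            collection_name="zettels",
            embedding_function=embeddings,
            persist_directory=settings.CHROMA_PERSIST_DIR,
        )

    async def add_notes(self, notes: List[Note]) -> None:
        # Note IDs double as Chroma document IDs so links can be resolved later
        await self.store.aadd_texts(
            texts=[note.content for note in notes],
            metadatas=[{"title": note.title, "note_id": str(note.id)} for note in notes],
            ids=[str(note.id) for note in notes],
        )

    async def semantic_search(self, query: str, k: int = 5) -> List[Document]:
        return await self.store.asimilarity_search(query, k=k)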
Task 3.2: Prompts (src/services/agents/prompts.py)
INTERVIEWER_SYSTEM_PROMPT: Socratic persona. Accepts {retrieved_context}. Instructs AI to ask one question at a time and to output the exact token [END_SESSION] when the conversation naturally concludes.
SEGMENTATION_PROMPT: Instructs LLM to analyze {transcript} and segment it into atomic ideas, enforcing the SegmentationResult structure.
GENERATIVE_LINKING_PROMPT: Instructs LLM to compare {new_note} against {neighbors}, enforcing the LinkingResult structure.
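One possible shape for these constants; the wording below is illustrative, and only the placeholder variables and the [END_SESSION] token come from this task:
Python
INTERVIEWER_SYSTEM_PROMPT = """You are a Socratic interviewer helping the user explore an idea.
Relevant existing knowledge:
{retrieved_context}

Ask exactly one probing question at a time. When the conversation reaches a natural
conclusion, output the exact token [END_SESSION]."""

SEGMENTATION_PROMPT = """Analyze the interview transcript below and segment it into atomic,
self-contained ideas (Zettels), each with a title, content, and tags.

Transcript:
{transcript}"""

GENERATIVE_LINKING_PROMPT = """Compare the new note against its nearest neighbors and explain
how they relate. Only propose links where there is a genuine conceptual connection.

New note:
{new_note}

Neighbors:
{neighbors}"""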
Phase 4: The Agent Layer
Task 4.1: Interviewer Agent (src/services/agents/interviewer.py)
Initialize the Interview LLM (Gemini Flash) using ChatGoogleGenerativeAI.
Dependencies: VectorService.
Method async generate_response(transcript: List[dict], context_query: str):
RAG: Call VectorService.semantic_search(context_query).
LCEL Chain: Define the chain: ChatPromptTemplate | LLM | StrOutputParser().
Invocation: Use .ainvoke() with formatted context and chat history.
Termination Check: Check if the response contains [END_SESSION].
Return the response text and a boolean indicating if the session should end.
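A minimal sketch of the agent described above, assuming the prompt constants from Task 3.2 and the VectorService from Task 3.1; the message-formatting details are illustrative:
Python
from typing import List, Tuple

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

from src.config import settings
from src.services.agents.prompts import INTERVIEWER_SYSTEM_PROMPT
from src.services.vector_service import VectorService

END_TOKEN = "[END_SESSION]"


class InterviewerAgent:
    def __init__(self, vector_service: VectorService) -> None:
        self.vector_service = vector_service
        llm = ChatGoogleGenerativeAI(model=settings.LLM_FLASH_MODEL)
        prompt = ChatPromptTemplate.from_messages([
            ("system", INTERVIEWER_SYSTEM_PROMPT),
            ("human", "{chat_history}"),
        ])
        # LCEL chain: prompt -> Gemini Flash -> plain string
        self.chain = prompt | llm | StrOutputParser()

    async def generate_response(self, transcript: List[dict], context_query: str) -> Tuple[str, bool]:
        # RAG: pull related notes to ground the next question
        docs = await self.vector_service.semantic_search(context_query)
        context = "\n\n".join(doc.page_content for doc in docs)
        history = "\n".join(f"{m['role']}: {m['content']}" for m in transcript)
        response = await self.chain.ainvoke(
            {"retrieved_context": context, "chat_history": history}
        )
        # Termination check: the system prompt asks the model to emit this exact token
        return response.replace(END_TOKEN, "").strip(), END_TOKEN in response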
Task 4.2: Synthesizer Agent (src/services/agents/synthesizer.py)
Initialize the Synthesis LLM (Gemini Pro).
Dependencies: VectorService.
Method async segment_transcript(transcript: List[dict]) -> SegmentationResult:
Define the structured chain: ChatPromptTemplate | LLM.with_structured_output(schema=SegmentationResult).
Invoke asynchronously and return the result.
Method async generate_links(new_note: Note, neighbors: List[Document]) -> LinkingResult:
Define the structured chain for linking using LinkingResult.
Format new_note and neighbors for the prompt.
Invoke asynchronously and return the result.
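A sketch of the synthesizer under the same assumptions; the prompt wiring and neighbor formatting are illustrative:
Python
from typing import List

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

from src.config import settings
from src.models import LinkingResult, Note, SegmentationResult
from src.services.agents.prompts import GENERATIVE_LINKING_PROMPT, SEGMENTATION_PROMPT


class SynthesizerAgent:
    def __init__(self) -> None:
        llm = ChatGoogleGenerativeAI(model=settings.LLM_PRO_MODEL)
        # Structured output: the model is constrained to the Pydantic schema, no text parsing
        self.segment_chain = (
            ChatPromptTemplate.from_template(SEGMENTATION_PROMPT)
            | llm.with_structured_output(SegmentationResult)
        )
        self.link_chain = (
            ChatPromptTemplate.from_template(GENERATIVE_LINKING_PROMPT)
            | llm.with_structured_output(LinkingResult)
        )

    async def segment_transcript(self, transcript: List[dict]) -> SegmentationResult:
        text = "\n".join(f"{m['role']}: {m['content']}" for m in transcript)
        return await self.segment_chain.ainvoke({"transcript": text})

    async def generate_links(self, new_note: Note, neighbors: List[Document]) -> LinkingResult:
        neighbor_text = "\n\n".join(
            f"id={doc.metadata.get('note_id')}\n{doc.page_content}" for doc in neighbors
        )
        return await self.link_chain.ainvoke(
            {"new_note": f"{new_note.title}\n{new_note.content}", "neighbors": neighbor_text}
        )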
Phase 5: Orchestration and API
Task 5.1: Session Orchestration Service (src/services/session_service.py)
This service manages the lifecycle and coordinates agents and repositories.
Dependencies: Repositories, Agents, VectorService.
async start_session(topic: str): Create session in DB. Call InterviewerAgent for the first message. Update transcript.
async handle_message(session_id: uuid.UUID, message: str): Get session. Update transcript with user message. Call InterviewerAgent. Update transcript with AI response. If [END_SESSION] detected, update status to processing.
async process_session_background_task(session_id: uuid.UUID): (The main synthesis pipeline)
Get session transcript.
Call SynthesizerAgent.segment_transcript.
Persistence (SQLite): Save new notes from SegmentationResult using NoteRepository.
Indexing (ChromaDB): Add the newly created notes using VectorService.add_notes.
Linking Loop:
For each new note, call VectorService.semantic_search to find neighbors.
Call SynthesizerAgent.generate_links.
Save links from LinkingResult using LinkRepository.
Finalize: Update session status to completed (or failed on error).
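A condensed sketch of the background pipeline described above, assuming the repositories from Task 2.3 and the agents from Phase 4; error handling and dependency wiring are simplified for illustration:
Python
import uuid

from src.db.database import get_session  # assumed usable as an async context manager
from src.db.repositories import LinkRepository, NoteRepository, SessionRepository
from src.models import Link, Note


async def process_session_background_task(session_id: uuid.UUID, synthesizer, vector_service) -> None:
    async with get_session() as db:
        sessions = SessionRepository(db)
        notes_repo = NoteRepository(db)
        links_repo = LinkRepository(db)
        session = await sessions.get(session_id)
        try:
            # 1. Segment the transcript into atomic Zettels
            segmentation = await synthesizer.segment_transcript(session.transcript)
            notes = [
                Note(title=z.title, content=z.content, tags=z.tags, session_id=session_id)
                for z in segmentation.notes
            ]
            notes = await notes_repo.create_bulk(notes)

            # 2. Index the new notes in ChromaDB for future RAG
            await vector_service.add_notes(notes)

            # 3. Generative linking: relate each new note to its nearest neighbors
            for note in notes:
                neighbors = await vector_service.semantic_search(note.content)
                linking = await synthesizer.generate_links(note, neighbors)
                links = [
                    Link(
                        context=raw.relationship_context,
                        source_id=note.id,
                        target_id=uuid.UUID(raw.target_note_id),
                    )
                    for raw in linking.links
                ]
                if links:
                    await links_repo.create_bulk(links)

            # 4. Finalize
            await sessions.update_status(session_id, "completed")
        except Exception:
            await sessions.update_status(session_id, "failed")
            raise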
Task 5.2: API Endpoints (src/main.py)
Initialize FastAPI app with a lifespan manager to call init_db().
Implement dependency injection for the DB session and SessionService.
POST /sessions.start: Takes StartSessionRequest, calls SessionService.start_session.
POST /sessions.sendMessage: Takes SendMessageRequest. Calls SessionService.handle_message. If the service indicates the session is ending, trigger the background task.
Background Task Trigger: Use FastAPI's BackgroundTasks to run SessionService.process_session_background_task(session_id). This is critical for returning an immediate API response.
GET /sessions.getStatus: Takes a session_id query parameter and returns the session's current status from the DB.
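A sketch of the FastAPI wiring for these endpoints. The SessionService construction and its return shapes (handle_message returning a reply plus an end-of-session flag, a get_status helper) are assumptions for illustration only:
Python
import uuid
from contextlib import asynccontextmanager

from fastapi import BackgroundTasks, FastAPI, HTTPException

from src.db.database import init_db
from src.models import SendMessageRequest, StartSessionRequest
from src.services.session_service import SessionService

service = SessionService()  # real code would build this via dependency injection


@asynccontextmanager
async def lifespan(app: FastAPI):
    await init_db()  # create tables on startup
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/sessions.start")
async def start_session(req: StartSessionRequest):
    return await service.start_session(req.topic)


@app.post("/sessions.sendMessage")
async def send_message(req: SendMessageRequest, background_tasks: BackgroundTasks):
    reply, session_ended = await service.handle_message(req.session_id, req.message)
    if session_ended:
        # Return immediately; synthesis runs after the response is sent
        background_tasks.add_task(service.process_session_background_task, req.session_id)
    return {"message": reply}


@app.get("/sessions.getStatus")
async def get_status(session_id: uuid.UUID):
    status = await service.get_status(session_id)
    if status is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return {"session_id": session_id, "status": status}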

pytest.ini (new file, 9 lines)

@@ -0,0 +1,9 @@
[pytest]
asyncio_mode = auto
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
markers =
    asyncio: mark test as an async test

requirements.txt (new file, 13 lines)

@@ -0,0 +1,13 @@
fastapi
uvicorn[standard]
pydantic-settings
sqlmodel
aiosqlite
langchain
langchain-google-genai
langchain-community
chromadb
tiktoken
pytest
pytest-asyncio
httpx

run_tests.py (new file, 132 lines)

@@ -0,0 +1,132 @@
#!/usr/bin/env python
"""
Test runner script for SkyTalk API.
This script can be run without pytest to validate basic functionality.
"""
import asyncio
import sys
from pathlib import Path

# Add the app directory to Python path
app_dir = Path(__file__).parent
sys.path.insert(0, str(app_dir))


async def test_basic_setup():
    """Test basic setup and imports."""
    print("🧪 Testing basic setup...")
    try:
        from app.core.config import settings
        print(f"✅ Config loaded: API key {'set' if settings.GOOGLE_API_KEY and settings.GOOGLE_API_KEY != 'your_google_api_key_here' else 'NOT set'}")

        from app.data.models import Session, Note, Link, StartSessionRequest
        print("✅ All models imported successfully")

        from app.data.database import init_db, get_session
        print("✅ Database modules imported successfully")

        from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
        print("✅ Repository modules imported successfully")
    except Exception as e:
        print(f"❌ Import error: {e}")
        return False
    return True


async def test_model_validation():
    """Test model creation and validation."""
    print("\n🧪 Testing model validation...")
    try:
        from app.data.models import StartSessionRequest, Session, Note
        import uuid

        # Test API models
        req = StartSessionRequest(topic="AI Ethics")
        print(f"✅ StartSessionRequest: {req.topic}")

        # Test database models
        session = Session()
        print(f"✅ Session created with ID: {session.id}")

        note = Note(
            title="Test Note",
            content="Test content",
            tags=["test"],
            session_id=session.id
        )
        print(f"✅ Note created with ID: {note.id}")
    except Exception as e:
        print(f"❌ Model validation error: {e}")
        return False
    return True


async def test_database_operations():
    """Test database initialization and basic operations."""
    print("\n🧪 Testing database operations...")
    try:
        from app.data.database import init_db, get_session
        from app.data.repositories import SessionRepository
        from app.data.models import Session

        # Initialize database
        await init_db()
        print("✅ Database tables created")

        # Test repository operations
        async with get_session() as db:
            repo = SessionRepository(db)
            session = await repo.create(status="test")
            print(f"✅ Session created via repository: {session.id}")

            retrieved = await repo.get(session.id)
            print(f"✅ Session retrieved: {retrieved.status}")
    except Exception as e:
        print(f"❌ Database operation error: {e}")
        return False
    return True


async def main():
    """Run all tests."""
    print("🚀 Running SkyTalk API validation tests...\n")
    tests = [
        test_basic_setup,
        test_model_validation,
        test_database_operations
    ]
    passed = 0
    total = len(tests)
    for test in tests:
        if await test():
            passed += 1
        else:
            break  # Stop on first failure

    print(f"\n📊 Test Results: {passed}/{total} passed")
    if passed == total:
        print("🎉 All basic validation tests passed!")
        print("\nNext steps:")
        print("1. Run full test suite: pytest")
        print("2. Continue with Phase 3 implementation")
    else:
        print("❌ Some tests failed. Please fix issues before continuing.")
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())

tests/__init__.py (new file, empty)

tests/test_config.py (new file, 24 lines)

@@ -0,0 +1,24 @@
"""Test configuration and environment setup."""
import pytest
from app.core.config import settings
def test_config_loads():
"""Test that configuration loads successfully."""
assert settings.GOOGLE_API_KEY is not None
assert settings.DATABASE_URL.startswith("sqlite+aiosqlite://")
assert settings.CHROMA_PERSIST_DIR == "./chroma_db"
def test_model_names():
"""Test that model names are configured correctly."""
assert "gemini" in settings.LLM_FLASH_MODEL.lower()
assert "gemini" in settings.LLM_PRO_MODEL.lower()
assert "embedding" in settings.EMBEDDING_MODEL.lower()
def test_api_settings():
"""Test API configuration."""
assert isinstance(settings.API_PORT, int)
assert settings.API_PORT > 0
assert settings.API_HOST is not None

tests/test_database.py (new file, 73 lines)

@@ -0,0 +1,73 @@
"""Test database initialization and basic operations."""
import pytest
import os
import tempfile
from sqlmodel import SQLModel
from app.data.database import init_db, get_session, engine
from app.data.models import Session, Note, Link
class TestDatabase:
"""Test database operations."""
@pytest.mark.asyncio
async def test_init_db(self):
"""Test database initialization creates tables."""
await init_db()
# Check that tables exist by attempting to create a session
async with get_session() as db:
session = Session(status="test")
db.add(session)
await db.commit()
await db.refresh(session)
assert session.id is not None
@pytest.mark.asyncio
async def test_session_context_manager(self):
"""Test database session context manager."""
await init_db()
# Test successful transaction
async with get_session() as db:
session = Session(status="test")
db.add(session)
# Context manager should auto-commit
# Verify data was saved
async with get_session() as db:
from sqlmodel import select
statement = select(Session).where(Session.status == "test")
result = await db.exec(statement)
found_session = result.first()
assert found_session is not None
assert found_session.status == "test"
@pytest.mark.asyncio
async def test_session_rollback_on_error(self):
"""Test that session rolls back on errors."""
await init_db()
initial_count = 0
async with get_session() as db:
from sqlmodel import select, func
result = await db.exec(select(func.count(Session.id)))
initial_count = result.one()
# Simulate an error during transaction
try:
async with get_session() as db:
session = Session(status="test_rollback")
db.add(session)
# Force an error
raise ValueError("Simulated error")
except ValueError:
pass
# Verify data was not saved due to rollback
async with get_session() as db:
from sqlmodel import select, func
result = await db.exec(select(func.count(Session.id)))
final_count = result.one()
assert final_count == initial_count # No new records

tests/test_models.py (new file, 131 lines)

@@ -0,0 +1,131 @@
"""Test data models and validation."""
import uuid
from datetime import datetime
import pytest
from pydantic import ValidationError
from app.data.models import (
StartSessionRequest,
SendMessageRequest,
SessionResponse,
SessionStatusResponse,
RawZettel,
SegmentationResult,
RawLink,
LinkingResult,
Note,
Session,
Link
)
class TestAPIModels:
"""Test API request/response models."""
def test_start_session_request(self):
"""Test StartSessionRequest validation."""
req = StartSessionRequest(topic="AI Ethics")
assert req.topic == "AI Ethics"
# Test empty topic fails
with pytest.raises(ValidationError):
StartSessionRequest(topic="")
def test_send_message_request(self):
"""Test SendMessageRequest validation."""
session_id = uuid.uuid4()
req = SendMessageRequest(session_id=session_id, message="Hello")
assert req.session_id == session_id
assert req.message == "Hello"
# Test invalid UUID fails
with pytest.raises(ValidationError):
SendMessageRequest(session_id="invalid", message="Hello")
def test_session_response(self):
"""Test SessionResponse model."""
session_id = uuid.uuid4()
resp = SessionResponse(
session_id=session_id,
status="active",
message="Welcome!"
)
assert resp.session_id == session_id
assert resp.status == "active"
assert resp.message == "Welcome!"
class TestLLMModels:
"""Test LLM structured output models."""
def test_raw_zettel(self):
"""Test RawZettel model."""
zettel = RawZettel(
title="Test Concept",
content="This is a test concept about AI.",
tags=["ai", "test", "concept"]
)
assert zettel.title == "Test Concept"
assert len(zettel.tags) == 3
assert "ai" in zettel.tags
def test_segmentation_result(self):
"""Test SegmentationResult model."""
zettels = [
RawZettel(title="Concept 1", content="Content 1", tags=["tag1"]),
RawZettel(title="Concept 2", content="Content 2", tags=["tag2"])
]
result = SegmentationResult(notes=zettels)
assert len(result.notes) == 2
assert result.notes[0].title == "Concept 1"
def test_raw_link(self):
"""Test RawLink model."""
link = RawLink(
target_note_id=str(uuid.uuid4()),
relationship_context="Related through shared AI concepts"
)
assert uuid.UUID(link.target_note_id) # Should not raise
assert "AI" in link.relationship_context
class TestDatabaseModels:
"""Test database models."""
def test_session_creation(self):
"""Test Session model creation."""
session = Session()
assert isinstance(session.id, uuid.UUID)
assert session.status == "active"
assert session.transcript == []
assert isinstance(session.created_at, datetime)
def test_note_creation(self):
"""Test Note model creation."""
session_id = uuid.uuid4()
note = Note(
title="Test Note",
content="This is test content",
tags=["test", "note"],
session_id=session_id
)
assert isinstance(note.id, uuid.UUID)
assert note.title == "Test Note"
assert note.content == "This is test content"
assert note.tags == ["test", "note"]
assert note.session_id == session_id
assert isinstance(note.created_at, datetime)
def test_link_creation(self):
"""Test Link model creation."""
source_id = uuid.uuid4()
target_id = uuid.uuid4()
link = Link(
context="These concepts are related through...",
source_id=source_id,
target_id=target_id
)
assert link.context == "These concepts are related through..."
assert link.source_id == source_id
assert link.target_id == target_id

tests/test_repositories.py (new file, 200 lines)

@@ -0,0 +1,200 @@
"""Test repository pattern implementations."""
import uuid
import pytest
from app.data.database import init_db, get_session
from app.data.repositories import SessionRepository, NoteRepository, LinkRepository
from app.data.models import Session, Note, Link
class TestSessionRepository:
"""Test SessionRepository operations."""
@pytest.mark.asyncio
async def test_create_session(self):
"""Test creating a new session."""
await init_db()
async with get_session() as db:
repo = SessionRepository(db)
session = await repo.create(status="active")
assert session.id is not None
assert session.status == "active"
assert session.transcript == []
@pytest.mark.asyncio
async def test_get_session(self):
"""Test retrieving a session by ID."""
await init_db()
# Create a session
session_id = None
async with get_session() as db:
repo = SessionRepository(db)
session = await repo.create(status="test")
session_id = session.id
# Retrieve the session
async with get_session() as db:
repo = SessionRepository(db)
retrieved = await repo.get(session_id)
assert retrieved is not None
assert retrieved.id == session_id
assert retrieved.status == "test"
@pytest.mark.asyncio
async def test_update_status(self):
"""Test updating session status."""
await init_db()
# Create a session
session_id = None
async with get_session() as db:
repo = SessionRepository(db)
session = await repo.create(status="active")
session_id = session.id
# Update status
async with get_session() as db:
repo = SessionRepository(db)
updated = await repo.update_status(session_id, "completed")
assert updated is not None
assert updated.status == "completed"
@pytest.mark.asyncio
async def test_append_transcript(self):
"""Test appending to session transcript."""
await init_db()
# Create a session
session_id = None
async with get_session() as db:
repo = SessionRepository(db)
session = await repo.create()
session_id = session.id
# Append to transcript
message = {"role": "user", "content": "Hello"}
async with get_session() as db:
repo = SessionRepository(db)
updated = await repo.append_transcript(session_id, message)
assert updated is not None
assert len(updated.transcript) == 1
assert updated.transcript[0] == message
class TestNoteRepository:
"""Test NoteRepository operations."""
@pytest.mark.asyncio
async def test_create_bulk_notes(self):
"""Test creating multiple notes at once."""
await init_db()
# Create a session first
session_id = None
async with get_session() as db:
session_repo = SessionRepository(db)
session = await session_repo.create()
session_id = session.id
# Create notes
notes = [
Note(
title="Note 1",
content="Content 1",
tags=["tag1"],
session_id=session_id
),
Note(
title="Note 2",
content="Content 2",
tags=["tag2"],
session_id=session_id
)
]
async with get_session() as db:
repo = NoteRepository(db)
created_notes = await repo.create_bulk(notes)
assert len(created_notes) == 2
assert all(note.id is not None for note in created_notes)
assert created_notes[0].title == "Note 1"
assert created_notes[1].title == "Note 2"
@pytest.mark.asyncio
async def test_get_by_session(self):
"""Test retrieving notes by session ID."""
await init_db()
# Create session and notes
session_id = None
async with get_session() as db:
session_repo = SessionRepository(db)
session = await session_repo.create()
session_id = session.id
notes = [
Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session_id),
Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session_id)
]
note_repo = NoteRepository(db)
await note_repo.create_bulk(notes)
# Retrieve notes
async with get_session() as db:
repo = NoteRepository(db)
session_notes = await repo.get_by_session(session_id)
assert len(session_notes) == 2
titles = [note.title for note in session_notes]
assert "Note 1" in titles
assert "Note 2" in titles
class TestLinkRepository:
"""Test LinkRepository operations."""
@pytest.mark.asyncio
async def test_create_bulk_links(self):
"""Test creating multiple links at once."""
await init_db()
# Create session and notes first
note_ids = []
async with get_session() as db:
session_repo = SessionRepository(db)
session = await session_repo.create()
notes = [
Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session.id),
Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session.id)
]
note_repo = NoteRepository(db)
created_notes = await note_repo.create_bulk(notes)
note_ids = [note.id for note in created_notes]
# Create links
links = [
Link(
context="These notes are related because...",
source_id=note_ids[0],
target_id=note_ids[1]
)
]
async with get_session() as db:
repo = LinkRepository(db)
created_links = await repo.create_bulk(links)
assert len(created_links) == 1
assert created_links[0].id is not None
assert created_links[0].source_id == note_ids[0]
assert created_links[0].target_id == note_ids[1]