From f60d61a78fe21026a8eb78f83b61e66b0b87f16f Mon Sep 17 00:00:00 2001 From: Albert Date: Sun, 17 Aug 2025 01:25:16 +0000 Subject: [PATCH] feat: implement data layer with comprehensive test infrastructure - Define SQLModel schemas for Session, Note, and Link entities - Add API request/response models for RPC endpoints - Create LLM structured output models for Zettel extraction - Set up async database initialization with SQLModel and aiosqlite - Implement repository pattern for CRUD operations - Add complete test suite with pytest configuration - Create validation test runner for development workflow - Add .gitignore for Python/FastAPI project security --- .claude/settings.local.json | 12 ++ .gitignore | 104 ++++++++++++++++ CLAUDE.md | 116 +++++++++++++++++ README.md | 57 +++++++++ app/core/config.py | 19 +++ app/data/database.py | 33 +++++ app/data/models.py | 73 +++++++++++ app/data/repositories.py | 98 +++++++++++++++ implementation-plan.txt | 241 ++++++++++++++++++++++++++++++++++++ pytest.ini | 9 ++ requirements.txt | 13 ++ run_tests.py | 132 ++++++++++++++++++++ tests/__init__.py | 0 tests/test_config.py | 24 ++++ tests/test_database.py | 73 +++++++++++ tests/test_models.py | 131 ++++++++++++++++++++ tests/test_repositories.py | 200 ++++++++++++++++++++++++++++++ 17 files changed, 1335 insertions(+) create mode 100644 .claude/settings.local.json create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 app/core/config.py create mode 100644 app/data/database.py create mode 100644 app/data/models.py create mode 100644 app/data/repositories.py create mode 100644 implementation-plan.txt create mode 100644 pytest.ini create mode 100644 requirements.txt create mode 100644 run_tests.py create mode 100644 tests/__init__.py create mode 100644 tests/test_config.py create mode 100644 tests/test_database.py create mode 100644 tests/test_models.py create mode 100644 tests/test_repositories.py diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..7f80347 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,12 @@ +{ + "permissions": { + "allow": [ + "Bash(mkdir:*)", + "Bash(python run_tests.py:*)", + "Bash(source:*)", + "Bash(pytest:*)" + ], + "deny": [], + "ask": [] + } +} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..347516d --- /dev/null +++ b/.gitignore @@ -0,0 +1,104 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Virtual environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE specific files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS specific files +.DS_Store +.DS_Store? 
+._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Project specific files +*.db +*.sqlite +*.sqlite3 +skytalk.db* + +# ChromaDB vector database +chroma_db/ +*.chroma + +# Logs +*.log +logs/ + +# Environment variables (keep template) +.env.local +.env.production +.env.development + +# FastAPI specific +.mypy_cache/ +.dmypy.json +dmypy.json + +# Jupyter Notebooks +.ipynb_checkpoints +*.ipynb + +# Documentation builds +docs/_build/ +site/ \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..966ff5d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,116 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with +code in this repository. + +## Project Overview + +SkyTalk API is an AI-powered backend service for conversational idea exploration +and knowledge synthesis. It interviews users to explore ideas and synthesizes +conversations into a structured, semantically-linked knowledge base using the +Zettelkasten method. + +## Core Development Commands + +### Environment Setup + +```bash +# Install dependencies using uv package manager +uv pip install -r requirements.txt + +# Run the API server (async FastAPI) +uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 +``` + +### Code Quality + +```bash +# Format code with black (mandatory) +black . + +# Type checking +mypy . + +# Linting +ruff check . +``` + +## Architecture & Implementation Standards + +### Three-Layer Architecture + +1. **API Layer**: FastAPI RPC-style endpoints (`/sessions/start`, + `/sessions/message`) +2. **Services Layer**: Orchestration logic with LangChain LCEL +3. **Data Layer**: SQLModel (SQLite) + ChromaDB (embedded, persisted) + +### Critical Implementation Rules + +**Database Models**: + +- SQLModel is the single source of truth for all database schemas +- ChromaDB runs in embedded mode, persisted to disk +- All database operations must be async + +**AI Integration**: + +- **Interviewing**: Use `gemini-2.5-flash-latest` (optimized for speed) +- **Synthesis/Linking**: Use `gemini-2.5-pro-latest` (optimized for reasoning) +- **Embeddings**: Use `models/text-embedding-004` +- **Structured Output**: Always use `.with_structured_output()` with Pydantic + models for data extraction - never parse raw text + +**Code Standards**: + +- Maximum 400 lines per Python file +- Full type hints required (Pydantic V2 for API, SQLModel for DB) +- All I/O operations must use async/await +- Configuration via environment variables with `pydantic-settings` +- Use `HTTPException` for client errors + +### Project Structure + +``` +api/ +├── app/ +│ ├── api/ # FastAPI endpoints (RPC-style) +│ ├── services/ # Business logic, LangChain agents +│ │ ├── interviewer.py +│ │ ├── synthesizer.py +│ │ └── vector.py +│ ├── data/ # Repositories and database models +│ │ ├── models/ # SQLModel definitions +│ │ └── repositories/ +│ ├── core/ # Configuration, prompts +│ └── main.py # FastAPI app initialization +├── requirements.txt # Dependencies managed via uv +└── .env # Environment variables +``` + +### Key Implementation Patterns + +**LangChain LCEL Pipeline**: + +```python +chain = prompt | llm.with_structured_output(OutputModel) | parser +result = await chain.ainvoke({"input": data}) +``` + +**Async Database Operations**: + +```python +async def get_session(session_id: str) -> Session: + async with get_session() as db: + return await db.get(Session, session_id) +``` + +**Background Task for Synthesis**: + +```python +background_tasks.add_task(synthesize_session, 
session_id) +``` + +### Overall Style + +- NO sycophancy -- push back/suggest alternative routes when it would help + improve the project diff --git a/README.md b/README.md index 27139b1..6eda783 100644 --- a/README.md +++ b/README.md @@ -81,3 +81,60 @@ SkyTalkAPI: { # Connections to External APIs SkyTalkAPI.Services -> ExternalAPIs.Gemini ``` + +--- + +## MVP Features (Phase 1) + +At the completion of this implementation plan, the SkyTalk API will support the +following core features: + +- **Session Management:** Start a new interview session based on an initial + topic. + +- **RAG-Powered Interviewing:** Engage in a back-and-forth conversation where + the AI's questions are informed by existing knowledge in the vector store. + +- **Automatic Session Termination:** The AI can detect a natural conclusion to + the conversation. + +- **Asynchronous Synthesis:** Once the interview ends, a background process is + triggered to analyze the transcript. + +- **Semantic Segmentation:** The transcript is intelligently broken down into + atomic "Zettels" (notes), each focusing on a single concept. + +- **Vector Indexing:** Each new note is converted into a vector embedding and + stored for future RAG. + +- **Generative Linking:** The system identifies semantically related notes and + uses an LLM to generate a rich, contextual link explaining the relationship + between them. + +- **Status Tracking:** Endpoints to check the status of a session (active, + processing, completed). + +--- + +## Next Steps (Post-MVP) + +- **Authentication:** Implement user accounts and authentication (e.g., JWT) to + create user-specific knowledge bases. + +- **Frontend Integration:** Build a web-based frontend (e.g., at + www.skytalk.app) that consumes this API. + +- **Knowledge Graph Visualization:** Add endpoints to export the note-and-link + structure in a format suitable for graph visualization libraries (e.g., D3.js, + Vis.js). + +- **Note Editing and Management:** Provide endpoints for users to manually edit, + delete, or merge notes. + +- **Advanced Search:** Implement more sophisticated search functionalities + beyond simple semantic search, such as filtering by tags or searching within + link contexts. + +- **Scalable Infrastructure:** Migrate from SQLite/embedded ChromaDB to a + production-grade database (e.g., PostgreSQL with pgvector) and a managed + vector database for scalability. 
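+
+---
+
+## Example Client Flow (Sketch)
+
+The snippet below sketches the intended session lifecycle from a client's
+point of view. It assumes the RPC-style route names from the implementation
+plan (`/sessions.start`, `/sessions.sendMessage`, `/sessions.getStatus`), the
+response fields of `SessionResponse`, and a local server on port 8000; adjust
+it to whatever `app/main.py` actually exposes once the API layer lands.
+
+```python
+import asyncio
+
+import httpx  # already listed in requirements.txt
+
+
+async def demo() -> None:
+    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
+        # Start an interview session on a topic.
+        resp = await client.post("/sessions.start", json={"topic": "AI Ethics"})
+        session_id = resp.json()["session_id"]
+
+        # Send a user message; the AI's next question comes back in the body.
+        resp = await client.post(
+            "/sessions.sendMessage",
+            json={"session_id": session_id, "message": "Start with incentives."},
+        )
+        print(resp.json()["message"])
+
+        # Poll the status; it moves to "processing" once the interview ends and
+        # to "completed" when background synthesis finishes.
+        resp = await client.get(
+            "/sessions.getStatus", params={"session_id": session_id}
+        )
+        print(resp.json()["status"])
+
+
+if __name__ == "__main__":
+    asyncio.run(demo())
+```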
diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000..eaf6318 --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,19 @@ +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") + + GOOGLE_API_KEY: str + DATABASE_URL: str = "sqlite+aiosqlite:///./skytalk.db" + CHROMA_PERSIST_DIR: str = "./chroma_db" + + LLM_FLASH_MODEL: str = "gemini-2.5-flash-latest" + LLM_PRO_MODEL: str = "gemini-2.5-pro-latest" + EMBEDDING_MODEL: str = "models/text-embedding-004" + + API_HOST: str = "0.0.0.0" + API_PORT: int = 8000 + + +settings = Settings() \ No newline at end of file diff --git a/app/data/database.py b/app/data/database.py new file mode 100644 index 0000000..ebfe75a --- /dev/null +++ b/app/data/database.py @@ -0,0 +1,33 @@ +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.config import settings + + +engine = create_async_engine( + settings.DATABASE_URL, + echo=False, + future=True +) + + +async def init_db() -> None: + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + +@asynccontextmanager +async def get_session() -> AsyncGenerator[AsyncSession, None]: + async with AsyncSession(engine) as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() \ No newline at end of file diff --git a/app/data/models.py b/app/data/models.py new file mode 100644 index 0000000..09c10c5 --- /dev/null +++ b/app/data/models.py @@ -0,0 +1,73 @@ +import uuid +from datetime import datetime +from typing import List, Optional, Any +from sqlmodel import Field, SQLModel, Relationship, JSON, Column + + +class StartSessionRequest(SQLModel): + topic: str = Field(min_length=1, description="Topic for the session") + + +class SendMessageRequest(SQLModel): + session_id: uuid.UUID + message: str = Field(min_length=1, description="User message content") + + +class SessionResponse(SQLModel): + session_id: uuid.UUID + status: str + message: str + + +class SessionStatusResponse(SQLModel): + session_id: uuid.UUID + status: str + notes_count: int = 0 + created_at: datetime + + +class RawZettel(SQLModel): + title: str = Field(description="Concise title of the atomic concept.") + content: str = Field(description="The synthesized mini-blog post.") + tags: List[str] = Field(description="List of relevant conceptual keywords.") + + +class SegmentationResult(SQLModel): + notes: List[RawZettel] + + +class RawLink(SQLModel): + target_note_id: str = Field(description="UUID of the related note.") + relationship_context: str = Field(description="Explanation of how the new note relates to the target note.") + + +class LinkingResult(SQLModel): + links: List[RawLink] + + +class Link(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + context: str + + source_id: uuid.UUID = Field(foreign_key="note.id") + target_id: uuid.UUID = Field(foreign_key="note.id") + + +class Note(SQLModel, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + title: str + content: str + tags: List[str] = Field(sa_column=Column(JSON)) + created_at: datetime = Field(default_factory=datetime.utcnow) + + session_id: uuid.UUID = Field(foreign_key="session.id") 
+ session: Optional["Session"] = Relationship(back_populates="notes") + + +class Session(SQLModel, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + status: str = Field(default="active") + transcript: List[dict] = Field(default=[], sa_column=Column(JSON)) + created_at: datetime = Field(default_factory=datetime.utcnow) + + notes: List["Note"] = Relationship(back_populates="session") \ No newline at end of file diff --git a/app/data/repositories.py b/app/data/repositories.py new file mode 100644 index 0000000..3f8a92d --- /dev/null +++ b/app/data/repositories.py @@ -0,0 +1,98 @@ +import uuid +from typing import List, Optional + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.data.models import Session, Note, Link + + +class SessionRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create(self, status: str = "active") -> Session: + db_session = Session(status=status) + self.session.add(db_session) + await self.session.commit() + await self.session.refresh(db_session) + return db_session + + async def get(self, session_id: uuid.UUID) -> Optional[Session]: + statement = select(Session).where(Session.id == session_id) + result = await self.session.exec(statement) + return result.first() + + async def update_status(self, session_id: uuid.UUID, status: str) -> Optional[Session]: + db_session = await self.get(session_id) + if db_session: + db_session.status = status + await self.session.commit() + await self.session.refresh(db_session) + return db_session + + async def append_transcript(self, session_id: uuid.UUID, message: dict) -> Optional[Session]: + db_session = await self.get(session_id) + if db_session: + transcript = db_session.transcript.copy() + transcript.append(message) + db_session.transcript = transcript + await self.session.commit() + await self.session.refresh(db_session) + return db_session + + async def get_with_notes(self, session_id: uuid.UUID) -> Optional[Session]: + statement = select(Session).where(Session.id == session_id) + result = await self.session.exec(statement) + session = result.first() + if session: + notes_statement = select(Note).where(Note.session_id == session_id) + notes_result = await self.session.exec(notes_statement) + session.notes = notes_result.all() + return session + + +class NoteRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_bulk(self, notes: List[Note]) -> List[Note]: + for note in notes: + self.session.add(note) + await self.session.commit() + for note in notes: + await self.session.refresh(note) + return notes + + async def get(self, note_id: uuid.UUID) -> Optional[Note]: + statement = select(Note).where(Note.id == note_id) + result = await self.session.exec(statement) + return result.first() + + async def get_by_session(self, session_id: uuid.UUID) -> List[Note]: + statement = select(Note).where(Note.session_id == session_id) + result = await self.session.exec(statement) + return result.all() + + +class LinkRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_bulk(self, links: List[Link]) -> List[Link]: + for link in links: + self.session.add(link) + await self.session.commit() + for link in links: + await self.session.refresh(link) + return links + + async def get_by_source(self, source_id: uuid.UUID) -> List[Link]: + statement = select(Link).where(Link.source_id == source_id) + result = await 
self.session.exec(statement) + return result.all() + + async def get_by_target(self, target_id: uuid.UUID) -> List[Link]: + statement = select(Link).where(Link.target_id == target_id) + result = await self.session.exec(statement) + return result.all() \ No newline at end of file diff --git a/implementation-plan.txt b/implementation-plan.txt new file mode 100644 index 0000000..c7e8c59 --- /dev/null +++ b/implementation-plan.txt @@ -0,0 +1,241 @@ +Product Requirements Document (PRD) & Implementation Plan +1. Overview +SkyTalk (api.skytalk.app) is an API service that functions as an AI interviewer to elicit user ideas, then processes the conversation to generate "Zettels" (atomic mini-blog posts). These Zettels are semantically linked using RAG and "Generative Linking." + +2. Implementation Tasks (Linear Execution Plan) +Phase 1: Project Setup and Core Utilities +Task 1.1: Project Structure Initialization +Create the following directory and file structure: + +Bash + +skytalk-api/ +├── .venv/ +├── src/ +│ ├── config.py +│ ├── db/ +│ │ ├── database.py +│ │ └── repositories.py +│ ├── services/ +│ │ ├── agents/ +│ │ │ ├── interviewer.py +│ │ │ ├── synthesizer.py +│ │ │ └── prompts.py +│ │ ├── session_service.py +│ │ └── vector_service.py +│ ├── models.py +│ └── main.py +├── .env +└── requirements.txt +Task 1.2: Dependencies and Configuration + +Create requirements.txt: + +Plaintext + +fastapi +uvicorn[standard] +pydantic-settings +sqlmodel +aiosqlite +langchain +langchain-google-genai +langchain-community +chromadb +tiktoken +Create src/config.py: + +Python + +from pydantic_settings import BaseSettings, SettingsConfigDict + +class Settings(BaseSettings): + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") + + GOOGLE_API_KEY: str + DATABASE_URL: str = "sqlite+aiosqlite:///./skytalk.db" + CHROMA_PERSIST_DIR: str = "./chroma_db" + + LLM_FLASH_MODEL: str = "gemini-2.5-flash-latest" + LLM_PRO_MODEL: str = "gemini-2.5-pro-latest" + EMBEDDING_MODEL: str = "models/text-embedding-004" + +settings = Settings() +Phase 2: Data Models and Persistence +Task 2.1: SQLModel Definitions (src/models.py) +Define all data structures, combining database models and API schemas. 
+ +Python + +import uuid +from datetime import datetime +from typing import List, Optional, Any +from sqlmodel import Field, SQLModel, Relationship, JSON, Column + +# --- API Payloads (Not stored in DB) --- +class StartSessionRequest(SQLModel): + topic: str + +class SendMessageRequest(SQLModel): + session_id: uuid.UUID + message: str + +# --- LLM Structured Output Models --- +class RawZettel(SQLModel): + title: str = Field(description="Concise title of the atomic concept.") + content: str = Field(description="The synthesized mini-blog post.") + tags: List[str] = Field(description="List of relevant conceptual keywords.") + +class SegmentationResult(SQLModel): + notes: List[RawZettel] + +class RawLink(SQLModel): + target_note_id: str = Field(description="UUID of the related note.") + relationship_context: str = Field(description="Explanation of how the new note relates to the target note.") + +class LinkingResult(SQLModel): + links: List[RawLink] + +# --- Database Models --- +class Link(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + context: str + + source_id: uuid.UUID = Field(foreign_key="note.id") + target_id: uuid.UUID = Field(foreign_key="note.id") + +class Note(SQLModel, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + title: str + content: str + tags: List[str] = Field(sa_column=Column(JSON)) + created_at: datetime = Field(default_factory=datetime.utcnow) + + session_id: uuid.UUID = Field(foreign_key="session.id") + session: "Session" = Relationship(back_populates="notes") + +class Session(SQLModel, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + status: str = Field(default="active") # active, processing, completed, failed + transcript: List[dict] = Field(default=[], sa_column=Column(JSON)) + created_at: datetime = Field(default_factory=datetime.utcnow) + + notes: List["Note"] = Relationship(back_populates="session") +Task 2.2: Database Initialization (src/db/database.py) + +Implement create_async_engine from sqlmodel.ext.asyncio.session. + +Define AsyncSession and init_db() to create tables on startup. + +Implement a FastAPI dependency get_session() to inject the AsyncSession. + +Task 2.3: Repositories (src/db/repositories.py) +Implement async CRUD operations using SQLModel's syntax. + +SessionRepository: create, get, update_status, append_transcript. + +NoteRepository: create_bulk (efficiently inserts multiple notes). + +LinkRepository: create_bulk. + +Phase 3: Core Services +Task 3.1: Vector Store Service (src/services/vector_service.py) + +Initialize GoogleGenerativeAIEmbeddings from langchain_google_genai. + +Initialize a persistent Chroma client. + +async add_notes(notes: List[Note]): Adds note content and metadata to Chroma, using note ids as document IDs. + +async semantic_search(query: str, k: int = 5): Performs similarity search and returns LangChain Document objects. + +Task 3.2: Prompts (src/services/agents/prompts.py) + +INTERVIEWER_SYSTEM_PROMPT: Socratic persona. Accepts {retrieved_context}. Instructs AI to ask one question at a time and to output the exact token [END_SESSION] when the conversation naturally concludes. + +SEGMENTATION_PROMPT: Instructs LLM to analyze {transcript} and segment it into atomic ideas, enforcing the SegmentationResult structure. + +GENERATIVE_LINKING_PROMPT: Instructs LLM to compare {new_note} against {neighbors}, enforcing the LinkingResult structure. 
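+
+A minimal sketch for Task 3.1 is included below. It assumes the LangChain
+Chroma wrapper from langchain-community, the async helpers aadd_texts and
+asimilarity_search on the vector store, and an illustrative collection name
+("zettels"); treat it as a starting point rather than the final implementation.
+
+Python
+
+from typing import List
+
+from langchain_community.vectorstores import Chroma
+from langchain_core.documents import Document
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+
+from src.config import settings
+from src.models import Note
+
+class VectorService:
+    def __init__(self) -> None:
+        self.embeddings = GoogleGenerativeAIEmbeddings(
+            model=settings.EMBEDDING_MODEL,
+            google_api_key=settings.GOOGLE_API_KEY,
+        )
+        # Embedded Chroma client, persisted to disk.
+        self.store = Chroma(
+            collection_name="zettels",
+            embedding_function=self.embeddings,
+            persist_directory=settings.CHROMA_PERSIST_DIR,
+        )
+
+    async def add_notes(self, notes: List[Note]) -> None:
+        # Note IDs double as Chroma document IDs so search hits can be mapped
+        # back to their SQLite rows during generative linking.
+        await self.store.aadd_texts(
+            texts=[note.content for note in notes],
+            metadatas=[{"note_id": str(note.id), "title": note.title} for note in notes],
+            ids=[str(note.id) for note in notes],
+        )
+
+    async def semantic_search(self, query: str, k: int = 5) -> List[Document]:
+        return await self.store.asimilarity_search(query, k=k)
+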
+ +Phase 4: The Agent Layer +Task 4.1: Interviewer Agent (src/services/agents/interviewer.py) + +Initialize the Interview LLM (Gemini Flash) using ChatGoogleGenerativeAI. + +Dependencies: VectorService. + +Method async generate_response(transcript: List[dict], context_query: str): + +RAG: Call VectorService.semantic_search(context_query). + +LCEL Chain: Define the chain: ChatPromptTemplate | LLM | StrOutputParser(). + +Invocation: Use .ainvoke() with formatted context and chat history. + +Termination Check: Check if the response contains [END_SESSION]. + +Return the response text and a boolean indicating if the session should end. + +Task 4.2: Synthesizer Agent (src/services/agents/synthesizer.py) + +Initialize the Synthesis LLM (Gemini Pro). + +Dependencies: VectorService. + +Method async segment_transcript(transcript: List[dict]) -> SegmentationResult: + +Define the structured chain: ChatPromptTemplate | LLM.with_structured_output(schema=SegmentationResult). + +Invoke asynchronously and return the result. + +Method async generate_links(new_note: Note, neighbors: List[Document]) -> LinkingResult: + +Define the structured chain for linking using LinkingResult. + +Format new_note and neighbors for the prompt. + +Invoke asynchronously and return the result. + +Phase 5: Orchestration and API +Task 5.1: Session Orchestration Service (src/services/session_service.py) +This service manages the lifecycle and coordinates agents and repositories. + +Dependencies: Repositories, Agents, VectorService. + +async start_session(topic: str): Create session in DB. Call InterviewerAgent for the first message. Update transcript. + +async handle_message(session_id: uuid.UUID, message: str): Get session. Update transcript with user message. Call InterviewerAgent. Update transcript with AI response. If [END_SESSION] detected, update status to processing. + +async process_session_background_task(session_id: uuid.UUID): (The main synthesis pipeline) + +Get session transcript. + +Call SynthesizerAgent.segment_transcript. + +Persistence (SQLite): Save new notes from SegmentationResult using NoteRepository. + +Indexing (ChromaDB): Add the newly created notes using VectorService.add_notes. + +Linking Loop: + +For each new note, call VectorService.semantic_search to find neighbors. + +Call SynthesizerAgent.generate_links. + +Save links from LinkingResult using LinkRepository. + +Finalize: Update session status to completed (or failed on error). + +Task 5.2: API Endpoints (src/main.py) + +Initialize FastAPI app with a lifespan manager to call init_db(). + +Implement dependency injection for the DB session and SessionService. + +POST /sessions.start: Takes StartSessionRequest, calls SessionService.start_session. + +POST /sessions.sendMessage: Takes SendMessageRequest. Calls SessionService.handle_message. If the service indicates the session is ending, trigger the background task. + +Background Task Trigger: Use FastAPI's BackgroundTasks to run SessionService.process_session_background_task(session_id). This is critical for returning an immediate API response. + +GET /sessions.getStatus: Takes a session_id query parameter and returns the session's current status from the DB. 
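+
+A condensed sketch of Task 5.2 follows. It assumes handle_message returns the
+AI reply together with a flag indicating that the session ended, and it uses a
+module-level SessionService instead of full dependency injection to keep the
+example short; the real main.py should wire these through FastAPI dependencies.
+
+Python
+
+from contextlib import asynccontextmanager
+
+from fastapi import BackgroundTasks, FastAPI
+
+from src.db.database import init_db
+from src.models import SendMessageRequest, StartSessionRequest
+from src.services.session_service import SessionService
+
+session_service = SessionService()  # placeholder; replace with dependency injection
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    await init_db()  # create tables on startup
+    yield
+
+app = FastAPI(lifespan=lifespan)
+
+@app.post("/sessions.start")
+async def start_session(payload: StartSessionRequest):
+    # Returns the session id plus the interviewer's opening question.
+    return await session_service.start_session(payload.topic)
+
+@app.post("/sessions.sendMessage")
+async def send_message(payload: SendMessageRequest, background_tasks: BackgroundTasks):
+    # Assumed return shape: (ai_reply, session_ended).
+    reply, session_ended = await session_service.handle_message(
+        payload.session_id, payload.message
+    )
+    if session_ended:
+        # Respond immediately; synthesis runs after the response is sent.
+        background_tasks.add_task(
+            session_service.process_session_background_task, payload.session_id
+        )
+    return {
+        "session_id": payload.session_id,
+        "status": "processing" if session_ended else "active",
+        "message": reply,
+    }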
diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..f6bb7cb --- /dev/null +++ b/pytest.ini @@ -0,0 +1,9 @@ +[tool:pytest] +asyncio_mode = auto +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = -v --tb=short +markers = + asyncio: mark test as an async test \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f01cde4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +fastapi +uvicorn[standard] +pydantic-settings +sqlmodel +aiosqlite +langchain +langchain-google-genai +langchain-community +chromadb +tiktoken +pytest +pytest-asyncio +httpx \ No newline at end of file diff --git a/run_tests.py b/run_tests.py new file mode 100644 index 0000000..01fbc96 --- /dev/null +++ b/run_tests.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +""" +Test runner script for SkyTalk API. + +This script can be run without pytest to validate basic functionality. +""" +import asyncio +import sys +from pathlib import Path + +# Add the app directory to Python path +app_dir = Path(__file__).parent +sys.path.insert(0, str(app_dir)) + + +async def test_basic_setup(): + """Test basic setup and imports.""" + print("🧪 Testing basic setup...") + + try: + from app.core.config import settings + print(f"✅ Config loaded: API key {'set' if settings.GOOGLE_API_KEY and settings.GOOGLE_API_KEY != 'your_google_api_key_here' else 'NOT set'}") + + from app.data.models import Session, Note, Link, StartSessionRequest + print("✅ All models imported successfully") + + from app.data.database import init_db, get_session + print("✅ Database modules imported successfully") + + from app.data.repositories import SessionRepository, NoteRepository, LinkRepository + print("✅ Repository modules imported successfully") + + except Exception as e: + print(f"❌ Import error: {e}") + return False + + return True + + +async def test_model_validation(): + """Test model creation and validation.""" + print("\n🧪 Testing model validation...") + + try: + from app.data.models import StartSessionRequest, Session, Note + import uuid + + # Test API models + req = StartSessionRequest(topic="AI Ethics") + print(f"✅ StartSessionRequest: {req.topic}") + + # Test database models + session = Session() + print(f"✅ Session created with ID: {session.id}") + + note = Note( + title="Test Note", + content="Test content", + tags=["test"], + session_id=session.id + ) + print(f"✅ Note created with ID: {note.id}") + + except Exception as e: + print(f"❌ Model validation error: {e}") + return False + + return True + + +async def test_database_operations(): + """Test database initialization and basic operations.""" + print("\n🧪 Testing database operations...") + + try: + from app.data.database import init_db, get_session + from app.data.repositories import SessionRepository + from app.data.models import Session + + # Initialize database + await init_db() + print("✅ Database tables created") + + # Test repository operations + async with get_session() as db: + repo = SessionRepository(db) + session = await repo.create(status="test") + print(f"✅ Session created via repository: {session.id}") + + retrieved = await repo.get(session.id) + print(f"✅ Session retrieved: {retrieved.status}") + + except Exception as e: + print(f"❌ Database operation error: {e}") + return False + + return True + + +async def main(): + """Run all tests.""" + print("🚀 Running SkyTalk API validation tests...\n") + + tests = [ + test_basic_setup, + test_model_validation, + 
test_database_operations + ] + + passed = 0 + total = len(tests) + + for test in tests: + if await test(): + passed += 1 + else: + break # Stop on first failure + + print(f"\n📊 Test Results: {passed}/{total} passed") + + if passed == total: + print("🎉 All basic validation tests passed!") + print("\nNext steps:") + print("1. Run full test suite: pytest") + print("2. Continue with Phase 3 implementation") + else: + print("❌ Some tests failed. Please fix issues before continuing.") + sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..58d97aa --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,24 @@ +"""Test configuration and environment setup.""" +import pytest +from app.core.config import settings + + +def test_config_loads(): + """Test that configuration loads successfully.""" + assert settings.GOOGLE_API_KEY is not None + assert settings.DATABASE_URL.startswith("sqlite+aiosqlite://") + assert settings.CHROMA_PERSIST_DIR == "./chroma_db" + + +def test_model_names(): + """Test that model names are configured correctly.""" + assert "gemini" in settings.LLM_FLASH_MODEL.lower() + assert "gemini" in settings.LLM_PRO_MODEL.lower() + assert "embedding" in settings.EMBEDDING_MODEL.lower() + + +def test_api_settings(): + """Test API configuration.""" + assert isinstance(settings.API_PORT, int) + assert settings.API_PORT > 0 + assert settings.API_HOST is not None \ No newline at end of file diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..aaff9be --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,73 @@ +"""Test database initialization and basic operations.""" +import pytest +import os +import tempfile +from sqlmodel import SQLModel + +from app.data.database import init_db, get_session, engine +from app.data.models import Session, Note, Link + + +class TestDatabase: + """Test database operations.""" + + @pytest.mark.asyncio + async def test_init_db(self): + """Test database initialization creates tables.""" + await init_db() + + # Check that tables exist by attempting to create a session + async with get_session() as db: + session = Session(status="test") + db.add(session) + await db.commit() + await db.refresh(session) + assert session.id is not None + + @pytest.mark.asyncio + async def test_session_context_manager(self): + """Test database session context manager.""" + await init_db() + + # Test successful transaction + async with get_session() as db: + session = Session(status="test") + db.add(session) + # Context manager should auto-commit + + # Verify data was saved + async with get_session() as db: + from sqlmodel import select + statement = select(Session).where(Session.status == "test") + result = await db.exec(statement) + found_session = result.first() + assert found_session is not None + assert found_session.status == "test" + + @pytest.mark.asyncio + async def test_session_rollback_on_error(self): + """Test that session rolls back on errors.""" + await init_db() + + initial_count = 0 + async with get_session() as db: + from sqlmodel import select, func + result = await db.exec(select(func.count(Session.id))) + initial_count = result.one() + + # Simulate an error during transaction + try: + async with get_session() as db: + session = Session(status="test_rollback") + db.add(session) + # Force an error + raise 
ValueError("Simulated error") + except ValueError: + pass + + # Verify data was not saved due to rollback + async with get_session() as db: + from sqlmodel import select, func + result = await db.exec(select(func.count(Session.id))) + final_count = result.one() + assert final_count == initial_count # No new records \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..938617d --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,131 @@ +"""Test data models and validation.""" +import uuid +from datetime import datetime + +import pytest +from pydantic import ValidationError + +from app.data.models import ( + StartSessionRequest, + SendMessageRequest, + SessionResponse, + SessionStatusResponse, + RawZettel, + SegmentationResult, + RawLink, + LinkingResult, + Note, + Session, + Link +) + + +class TestAPIModels: + """Test API request/response models.""" + + def test_start_session_request(self): + """Test StartSessionRequest validation.""" + req = StartSessionRequest(topic="AI Ethics") + assert req.topic == "AI Ethics" + + # Test empty topic fails + with pytest.raises(ValidationError): + StartSessionRequest(topic="") + + def test_send_message_request(self): + """Test SendMessageRequest validation.""" + session_id = uuid.uuid4() + req = SendMessageRequest(session_id=session_id, message="Hello") + assert req.session_id == session_id + assert req.message == "Hello" + + # Test invalid UUID fails + with pytest.raises(ValidationError): + SendMessageRequest(session_id="invalid", message="Hello") + + def test_session_response(self): + """Test SessionResponse model.""" + session_id = uuid.uuid4() + resp = SessionResponse( + session_id=session_id, + status="active", + message="Welcome!" + ) + assert resp.session_id == session_id + assert resp.status == "active" + assert resp.message == "Welcome!" 
+ + +class TestLLMModels: + """Test LLM structured output models.""" + + def test_raw_zettel(self): + """Test RawZettel model.""" + zettel = RawZettel( + title="Test Concept", + content="This is a test concept about AI.", + tags=["ai", "test", "concept"] + ) + assert zettel.title == "Test Concept" + assert len(zettel.tags) == 3 + assert "ai" in zettel.tags + + def test_segmentation_result(self): + """Test SegmentationResult model.""" + zettels = [ + RawZettel(title="Concept 1", content="Content 1", tags=["tag1"]), + RawZettel(title="Concept 2", content="Content 2", tags=["tag2"]) + ] + result = SegmentationResult(notes=zettels) + assert len(result.notes) == 2 + assert result.notes[0].title == "Concept 1" + + def test_raw_link(self): + """Test RawLink model.""" + link = RawLink( + target_note_id=str(uuid.uuid4()), + relationship_context="Related through shared AI concepts" + ) + assert uuid.UUID(link.target_note_id) # Should not raise + assert "AI" in link.relationship_context + + +class TestDatabaseModels: + """Test database models.""" + + def test_session_creation(self): + """Test Session model creation.""" + session = Session() + assert isinstance(session.id, uuid.UUID) + assert session.status == "active" + assert session.transcript == [] + assert isinstance(session.created_at, datetime) + + def test_note_creation(self): + """Test Note model creation.""" + session_id = uuid.uuid4() + note = Note( + title="Test Note", + content="This is test content", + tags=["test", "note"], + session_id=session_id + ) + assert isinstance(note.id, uuid.UUID) + assert note.title == "Test Note" + assert note.content == "This is test content" + assert note.tags == ["test", "note"] + assert note.session_id == session_id + assert isinstance(note.created_at, datetime) + + def test_link_creation(self): + """Test Link model creation.""" + source_id = uuid.uuid4() + target_id = uuid.uuid4() + link = Link( + context="These concepts are related through...", + source_id=source_id, + target_id=target_id + ) + assert link.context == "These concepts are related through..." 
+ assert link.source_id == source_id + assert link.target_id == target_id \ No newline at end of file diff --git a/tests/test_repositories.py b/tests/test_repositories.py new file mode 100644 index 0000000..a7dce19 --- /dev/null +++ b/tests/test_repositories.py @@ -0,0 +1,200 @@ +"""Test repository pattern implementations.""" +import uuid +import pytest + +from app.data.database import init_db, get_session +from app.data.repositories import SessionRepository, NoteRepository, LinkRepository +from app.data.models import Session, Note, Link + + +class TestSessionRepository: + """Test SessionRepository operations.""" + + @pytest.mark.asyncio + async def test_create_session(self): + """Test creating a new session.""" + await init_db() + + async with get_session() as db: + repo = SessionRepository(db) + session = await repo.create(status="active") + + assert session.id is not None + assert session.status == "active" + assert session.transcript == [] + + @pytest.mark.asyncio + async def test_get_session(self): + """Test retrieving a session by ID.""" + await init_db() + + # Create a session + session_id = None + async with get_session() as db: + repo = SessionRepository(db) + session = await repo.create(status="test") + session_id = session.id + + # Retrieve the session + async with get_session() as db: + repo = SessionRepository(db) + retrieved = await repo.get(session_id) + + assert retrieved is not None + assert retrieved.id == session_id + assert retrieved.status == "test" + + @pytest.mark.asyncio + async def test_update_status(self): + """Test updating session status.""" + await init_db() + + # Create a session + session_id = None + async with get_session() as db: + repo = SessionRepository(db) + session = await repo.create(status="active") + session_id = session.id + + # Update status + async with get_session() as db: + repo = SessionRepository(db) + updated = await repo.update_status(session_id, "completed") + + assert updated is not None + assert updated.status == "completed" + + @pytest.mark.asyncio + async def test_append_transcript(self): + """Test appending to session transcript.""" + await init_db() + + # Create a session + session_id = None + async with get_session() as db: + repo = SessionRepository(db) + session = await repo.create() + session_id = session.id + + # Append to transcript + message = {"role": "user", "content": "Hello"} + async with get_session() as db: + repo = SessionRepository(db) + updated = await repo.append_transcript(session_id, message) + + assert updated is not None + assert len(updated.transcript) == 1 + assert updated.transcript[0] == message + + +class TestNoteRepository: + """Test NoteRepository operations.""" + + @pytest.mark.asyncio + async def test_create_bulk_notes(self): + """Test creating multiple notes at once.""" + await init_db() + + # Create a session first + session_id = None + async with get_session() as db: + session_repo = SessionRepository(db) + session = await session_repo.create() + session_id = session.id + + # Create notes + notes = [ + Note( + title="Note 1", + content="Content 1", + tags=["tag1"], + session_id=session_id + ), + Note( + title="Note 2", + content="Content 2", + tags=["tag2"], + session_id=session_id + ) + ] + + async with get_session() as db: + repo = NoteRepository(db) + created_notes = await repo.create_bulk(notes) + + assert len(created_notes) == 2 + assert all(note.id is not None for note in created_notes) + assert created_notes[0].title == "Note 1" + assert created_notes[1].title == "Note 2" + + 
@pytest.mark.asyncio + async def test_get_by_session(self): + """Test retrieving notes by session ID.""" + await init_db() + + # Create session and notes + session_id = None + async with get_session() as db: + session_repo = SessionRepository(db) + session = await session_repo.create() + session_id = session.id + + notes = [ + Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session_id), + Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session_id) + ] + + note_repo = NoteRepository(db) + await note_repo.create_bulk(notes) + + # Retrieve notes + async with get_session() as db: + repo = NoteRepository(db) + session_notes = await repo.get_by_session(session_id) + + assert len(session_notes) == 2 + titles = [note.title for note in session_notes] + assert "Note 1" in titles + assert "Note 2" in titles + + +class TestLinkRepository: + """Test LinkRepository operations.""" + + @pytest.mark.asyncio + async def test_create_bulk_links(self): + """Test creating multiple links at once.""" + await init_db() + + # Create session and notes first + note_ids = [] + async with get_session() as db: + session_repo = SessionRepository(db) + session = await session_repo.create() + + notes = [ + Note(title="Note 1", content="Content 1", tags=["tag1"], session_id=session.id), + Note(title="Note 2", content="Content 2", tags=["tag2"], session_id=session.id) + ] + + note_repo = NoteRepository(db) + created_notes = await note_repo.create_bulk(notes) + note_ids = [note.id for note in created_notes] + + # Create links + links = [ + Link( + context="These notes are related because...", + source_id=note_ids[0], + target_id=note_ids[1] + ) + ] + + async with get_session() as db: + repo = LinkRepository(db) + created_links = await repo.create_bulk(links) + + assert len(created_links) == 1 + assert created_links[0].id is not None + assert created_links[0].source_id == note_ids[0] + assert created_links[0].target_id == note_ids[1] \ No newline at end of file