feat: implement AI agent layer with LangChain integration
- Create InterviewerAgent with Socratic questioning and RAG context
- Build SynthesizerAgent for transcript segmentation and link generation
- Integrate Google Gemini models (Flash for interviewing, Pro for synthesis)
- Add structured output parsing for Zettel extraction and linking
- Implement session termination detection with [END_SESSION] token
- Add conversation context formatting and similarity-based neighbor filtering
- Add vector service tests with mocked ChromaDB and embeddings
- Test interviewer agent RAG conversations and session termination
- Test synthesizer agent transcript formatting and neighbor analysis
- Add prompt loader tests for external prompt system
- Test all repository CRUD operations and database transactions
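Taken together, the two agents are assumed to compose into an interview-then-synthesize flow. The sketch below is illustrative only: InterviewerAgent, SynthesizerAgent, VectorService, and their method signatures come from this commit, while the driver function and the `zettels` attribute on SegmentationResult are assumptions.

    import asyncio

    from app.services.interviewer import InterviewerAgent
    from app.services.synthesizer import SynthesizerAgent
    from app.services.vector import VectorService

    async def run_session(transcript):
        # One vector service backs both RAG retrieval and neighbor lookup
        vector_service = VectorService()
        interviewer = InterviewerAgent(vector_service)
        synthesizer = SynthesizerAgent(vector_service)

        # Drive the interview; the agent signals completion via [END_SESSION]
        reply, should_end = await interviewer.generate_response(
            transcript, context_query=transcript[-1]["content"]
        )

        if should_end:
            # Distill the finished conversation into atomic Zettels
            segmentation = await synthesizer.segment_transcript(transcript)
            for note in segmentation.zettels:  # attribute name is an assumption
                neighbors = await vector_service.semantic_search(note.content, k=5)
                links = await synthesizer.generate_links(note, neighbors)
                # Persisting notes and links is left to the repository layer (not shown)
        return reply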
app/services/interviewer.py (new file, 96 lines)
@@ -0,0 +1,96 @@
from typing import List, Dict, Tuple

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document

from app.core.config import settings
from app.core.prompt_loader import prompt_loader
from app.services.vector import VectorService


class InterviewerAgent:
    """AI agent responsible for conducting Socratic interviews."""

    def __init__(self, vector_service: VectorService):
        self.llm = ChatGoogleGenerativeAI(
            model=settings.LLM_FLASH_MODEL,
            google_api_key=settings.GOOGLE_API_KEY,
            temperature=0.7
        )
        self.vector_service = vector_service
        self.output_parser = StrOutputParser()

        # Create the chain
        system_prompt = prompt_loader.get_interviewer_prompt()
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{user_input}")
        ])

        self.chain = self.prompt_template | self.llm | self.output_parser

    async def generate_response(
        self,
        transcript: List[Dict[str, str]],
        context_query: str
    ) -> Tuple[str, bool]:
        """Generate an interview response based on conversation history and context.

        Args:
            transcript: List of conversation messages with 'role' and 'content' keys
            context_query: Query string for retrieving relevant context

        Returns:
            Tuple of (response_text, should_end_session)
        """
        # Retrieve relevant context from the vector store
        context_docs = await self.vector_service.semantic_search(context_query, k=3)
        context_text = self._format_context(context_docs)

        # Format the conversation history
        conversation_text = self._format_transcript(transcript)

        # Get the latest user input
        user_input = transcript[-1]["content"] if transcript else ""

        # Generate the response
        response = await self.chain.ainvoke({
            "retrieved_context": context_text,
            "user_input": f"Conversation so far:\n{conversation_text}\n\nLatest message: {user_input}"
        })

        # Check for session termination
        should_end = "[END_SESSION]" in response
        if should_end:
            # Remove the termination token from the response
            response = response.replace("[END_SESSION]", "").strip()
            if not response:
                response = "Thank you for this thoughtful conversation! I believe we've explored your ideas thoroughly."

        return response, should_end

    def _format_context(self, context_docs: List[Document]) -> str:
        """Format context documents for the prompt."""
        if not context_docs:
            return "No relevant context from previous conversations."

        formatted_context = []
        for doc in context_docs:
            title = doc.metadata.get("title", "Untitled")
            content = doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
            formatted_context.append(f"• {title}: {content}")

        return "\n".join(formatted_context)

    def _format_transcript(self, transcript: List[Dict[str, str]]) -> str:
        """Format conversation transcript for the prompt."""
        if not transcript:
            return "This is the start of the conversation."

        formatted_messages = []
        for message in transcript[-6:]:  # Last 6 messages for context
            role = "User" if message["role"] == "user" else "Assistant"
            formatted_messages.append(f"{role}: {message['content']}")

        return "\n".join(formatted_messages)
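A minimal usage sketch of the termination contract, assuming a configured GOOGLE_API_KEY and a populated vector store:

    import asyncio

    from app.services.interviewer import InterviewerAgent
    from app.services.vector import VectorService

    async def demo() -> None:
        agent = InterviewerAgent(VectorService())
        transcript = [{"role": "user", "content": "I keep circling back to AI ethics."}]
        reply, should_end = await agent.generate_response(transcript, context_query="AI ethics")
        # [END_SESSION] never reaches the caller; it is stripped from the text
        # and surfaced as the boolean flag instead.
        print(reply, should_end)

    asyncio.run(demo())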
app/services/synthesizer.py (new file, 113 lines)
@@ -0,0 +1,113 @@
from typing import List, Dict

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document

from app.core.config import settings
from app.core.prompt_loader import prompt_loader
from app.data.models import SegmentationResult, LinkingResult, Note
from app.services.vector import VectorService


class SynthesizerAgent:
    """AI agent responsible for synthesizing conversations into Zettels and creating links."""

    def __init__(self, vector_service: VectorService):
        self.llm = ChatGoogleGenerativeAI(
            model=settings.LLM_PRO_MODEL,
            google_api_key=settings.GOOGLE_API_KEY,
            temperature=0.3  # Lower temperature for more consistent synthesis
        )
        self.vector_service = vector_service

    async def segment_transcript(self, transcript: List[Dict[str, str]]) -> SegmentationResult:
        """Segment a conversation transcript into atomic Zettels.

        Args:
            transcript: List of conversation messages with 'role' and 'content' keys

        Returns:
            SegmentationResult containing extracted Zettels
        """
        # Format transcript for analysis
        transcript_text = self._format_transcript_for_segmentation(transcript)

        # Create segmentation chain
        segmentation_prompt = prompt_loader.get_segmentation_prompt()
        prompt_template = ChatPromptTemplate.from_template(segmentation_prompt)

        chain = prompt_template | self.llm.with_structured_output(
            schema=SegmentationResult,
            method="json_mode"
        )

        # Generate segmentation
        result = await chain.ainvoke({"transcript": transcript_text})
        return result

    async def generate_links(self, new_note: Note, neighbors: List[Document]) -> LinkingResult:
        """Generate semantic links between a new note and existing notes.

        Args:
            new_note: The newly created note to link
            neighbors: List of potentially related notes from vector search

        Returns:
            LinkingResult containing identified links
        """
        # Format neighbors for analysis
        neighbors_text = self._format_neighbors(neighbors)

        # Create linking chain
        linking_prompt = prompt_loader.get_linking_prompt()
        prompt_template = ChatPromptTemplate.from_template(linking_prompt)

        chain = prompt_template | self.llm.with_structured_output(
            schema=LinkingResult,
            method="json_mode"
        )

        # Generate links
        result = await chain.ainvoke({
            "new_note_title": new_note.title,
            "new_note_content": new_note.content,
            "new_note_tags": ", ".join(new_note.tags),
            "neighbors": neighbors_text
        })

        return result

    def _format_transcript_for_segmentation(self, transcript: List[Dict[str, str]]) -> str:
        """Format transcript for segmentation analysis."""
        if not transcript:
            return "Empty conversation transcript."

        formatted_messages = []
        for message in transcript:
            role = "User" if message["role"] == "user" else "AI Interviewer"
            content = message["content"]
            formatted_messages.append(f"{role}: {content}")

        return "\n\n".join(formatted_messages)

    def _format_neighbors(self, neighbors: List[Document]) -> str:
        """Format neighboring documents for linking analysis."""
        if not neighbors:
            return "No existing notes found for comparison."

        formatted_neighbors = []
        for i, doc in enumerate(neighbors, 1):
            title = doc.metadata.get("title", f"Note {i}")
            content = doc.page_content
            tags = doc.metadata.get("tags", "")
            similarity = doc.metadata.get("similarity_score", 0.0)

            neighbor_text = f"""Note {i}: {title}
Content: {content}
Tags: {tags}
Similarity Score: {similarity:.3f}"""

            formatted_neighbors.append(neighbor_text)

        return "\n\n" + "\n\n".join(formatted_neighbors)
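SegmentationResult and LinkingResult are imported from app.data.models but not part of this diff. For json_mode structured output they would need to be Pydantic models roughly along these lines; the field names here are assumptions, not the committed schema:

    from typing import List
    from pydantic import BaseModel

    class Zettel(BaseModel):
        """One atomic idea extracted from a conversation."""
        title: str
        content: str
        tags: List[str]

    class SegmentationResult(BaseModel):
        zettels: List[Zettel]

    class Link(BaseModel):
        """A proposed semantic connection to an existing note."""
        target_title: str
        relationship: str
        reasoning: str

    class LinkingResult(BaseModel):
        links: List[Link]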
tests/test_interviewer_agent.py (new file, 200 lines)
@@ -0,0 +1,200 @@
"""Test interviewer agent functionality."""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock

from app.services.interviewer import InterviewerAgent
from app.services.vector import VectorService
from langchain_core.documents import Document


class TestInterviewerAgent:
    """Test InterviewerAgent functionality."""

    @pytest.fixture
    def mock_vector_service(self):
        """Create a mock vector service."""
        mock_service = AsyncMock(spec=VectorService)
        mock_service.semantic_search = AsyncMock(return_value=[
            Document(
                page_content="AI ethics is important for responsible development.",
                metadata={"title": "AI Ethics", "similarity_score": 0.8}
            )
        ])
        return mock_service

    @pytest.fixture
    def mock_llm_response(self):
        """Mock LLM response."""
        return "What specific aspects of AI ethics are you most concerned about?"

    @pytest.mark.asyncio
    async def test_interviewer_initialization(self, mock_vector_service):
        """Test that InterviewerAgent initializes correctly."""
        with patch('app.services.interviewer.ChatGoogleGenerativeAI') as mock_llm:
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt {retrieved_context}"

                agent = InterviewerAgent(mock_vector_service)
                assert agent.vector_service is mock_vector_service
                assert agent.llm is not None
                assert agent.chain is not None

    @pytest.mark.asyncio
    async def test_generate_response_basic(self, mock_vector_service, mock_llm_response):
        """Test basic response generation."""
        with patch('app.services.interviewer.ChatGoogleGenerativeAI') as mock_llm_class:
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt {retrieved_context}"

                # Mock the LLM chain
                mock_llm = MagicMock()
                mock_llm_class.return_value = mock_llm

                # Mock the chain invocation
                agent = InterviewerAgent(mock_vector_service)
                agent.chain = AsyncMock()
                agent.chain.ainvoke = AsyncMock(return_value=mock_llm_response)

                transcript = [{"role": "user", "content": "I want to discuss AI ethics"}]
                response, should_end = await agent.generate_response(transcript, "AI ethics")

                assert response == mock_llm_response
                assert should_end is False
                assert mock_vector_service.semantic_search.called

    @pytest.mark.asyncio
    async def test_generate_response_with_end_session(self, mock_vector_service):
        """Test response generation with session termination."""
        with patch('app.services.interviewer.ChatGoogleGenerativeAI') as mock_llm_class:
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt {retrieved_context}"

                mock_llm = MagicMock()
                mock_llm_class.return_value = mock_llm

                agent = InterviewerAgent(mock_vector_service)
                agent.chain = AsyncMock()
                agent.chain.ainvoke = AsyncMock(
                    return_value="Great insights! [END_SESSION] Thank you for the conversation."
                )

                transcript = [{"role": "user", "content": "I think we've covered everything"}]
                response, should_end = await agent.generate_response(transcript, "conclusion")

                assert should_end is True
                assert "[END_SESSION]" not in response
                assert "Great insights!" in response
                assert "Thank you" in response

    @pytest.mark.asyncio
    async def test_generate_response_empty_transcript(self, mock_vector_service, mock_llm_response):
        """Test response generation with an empty transcript."""
        with patch('app.services.interviewer.ChatGoogleGenerativeAI') as mock_llm_class:
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt {retrieved_context}"

                mock_llm = MagicMock()
                mock_llm_class.return_value = mock_llm

                agent = InterviewerAgent(mock_vector_service)
                agent.chain = AsyncMock()
                agent.chain.ainvoke = AsyncMock(return_value=mock_llm_response)

                response, should_end = await agent.generate_response([], "test query")

                assert response == mock_llm_response
                assert should_end is False

    def test_format_context_with_documents(self):
        """Test context formatting with documents."""
        mock_vector_service = MagicMock()

        with patch('app.services.interviewer.ChatGoogleGenerativeAI'):
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt"

                agent = InterviewerAgent(mock_vector_service)

                docs = [
                    Document(
                        page_content="Long content " * 50,  # Will be truncated
                        metadata={"title": "Test Doc 1"}
                    ),
                    Document(
                        page_content="Short content",
                        metadata={"title": "Test Doc 2"}
                    )
                ]

                result = agent._format_context(docs)
                assert "Test Doc 1" in result
                assert "Test Doc 2" in result
                assert "..." in result  # Truncation indicator

    def test_format_context_empty(self):
        """Test context formatting with no documents."""
        mock_vector_service = MagicMock()

        with patch('app.services.interviewer.ChatGoogleGenerativeAI'):
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt"

                agent = InterviewerAgent(mock_vector_service)
                result = agent._format_context([])
                assert "No relevant context" in result

    def test_format_transcript(self):
        """Test transcript formatting."""
        mock_vector_service = MagicMock()

        with patch('app.services.interviewer.ChatGoogleGenerativeAI'):
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt"

                agent = InterviewerAgent(mock_vector_service)

                transcript = [
                    {"role": "user", "content": "Hello"},
                    {"role": "assistant", "content": "Hi there!"},
                    {"role": "user", "content": "Let's discuss AI"}
                ]

                result = agent._format_transcript(transcript)
                assert "User: Hello" in result
                assert "Assistant: Hi there!" in result
                assert "User: Let's discuss AI" in result

    def test_format_transcript_empty(self):
        """Test transcript formatting with an empty list."""
        mock_vector_service = MagicMock()

        with patch('app.services.interviewer.ChatGoogleGenerativeAI'):
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt"

                agent = InterviewerAgent(mock_vector_service)
                result = agent._format_transcript([])
                assert "start of the conversation" in result

    def test_format_transcript_long_conversation(self):
        """Test that transcript formatting limits context to the last 6 messages."""
        mock_vector_service = MagicMock()

        with patch('app.services.interviewer.ChatGoogleGenerativeAI'):
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = "Test prompt"

                agent = InterviewerAgent(mock_vector_service)

                # Create 10 messages
                transcript = []
                for i in range(10):
                    transcript.append({"role": "user", "content": f"Message {i}"})

                result = agent._format_transcript(transcript)

                # Should only contain the last 6 messages
                assert "Message 4" in result  # Included (10 - 6 = 4)
                assert "Message 9" in result  # Included (last message)
                assert "Message 0" not in result  # Excluded
                assert "Message 3" not in result  # Excluded
tests/test_prompt_loader.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""Test prompt loading functionality."""
import pytest
from pathlib import Path

from app.core.prompt_loader import PromptLoader, prompt_loader


class TestPromptLoader:
    """Test PromptLoader functionality."""

    def test_prompt_loader_initialization(self):
        """Test that PromptLoader initializes correctly."""
        loader = PromptLoader()
        assert loader.prompts_dir.exists()
        assert loader.prompts_dir.name == "prompts"

    def test_load_existing_prompt(self):
        """Test loading an existing prompt file."""
        loader = PromptLoader()
        interviewer_prompt = loader.load_prompt("interviewer")

        assert isinstance(interviewer_prompt, str)
        assert len(interviewer_prompt) > 0
        assert "Socratic interviewer" in interviewer_prompt
        assert "{retrieved_context}" in interviewer_prompt
        assert "[END_SESSION]" in interviewer_prompt

    def test_load_nonexistent_prompt(self):
        """Test loading a non-existent prompt file."""
        loader = PromptLoader()

        with pytest.raises(FileNotFoundError):
            loader.load_prompt("nonexistent_prompt")

    def test_get_interviewer_prompt(self):
        """Test the interviewer prompt getter."""
        loader = PromptLoader()
        prompt = loader.get_interviewer_prompt()

        assert isinstance(prompt, str)
        assert "interviewer" in prompt.lower()
        assert "question" in prompt.lower()

    def test_get_segmentation_prompt(self):
        """Test the segmentation prompt getter."""
        loader = PromptLoader()
        prompt = loader.get_segmentation_prompt()

        assert isinstance(prompt, str)
        assert "transcript" in prompt.lower()
        assert "zettel" in prompt.lower()
        assert "{transcript}" in prompt

    def test_get_linking_prompt(self):
        """Test the linking prompt getter."""
        loader = PromptLoader()
        prompt = loader.get_linking_prompt()

        assert isinstance(prompt, str)
        assert "relationship" in prompt.lower()
        assert "semantic" in prompt.lower()
        assert "{new_note_title}" in prompt
        assert "{neighbors}" in prompt

    def test_caching_functionality(self):
        """Test that prompts are cached."""
        loader = PromptLoader()

        # Load the same prompt twice
        prompt1 = loader.load_prompt("interviewer")
        prompt2 = loader.load_prompt("interviewer")

        # They should be the same object (cached)
        assert prompt1 is prompt2

    def test_global_prompt_loader_instance(self):
        """Test that the global prompt_loader instance works."""
        interviewer_prompt = prompt_loader.get_interviewer_prompt()
        segmentation_prompt = prompt_loader.get_segmentation_prompt()
        linking_prompt = prompt_loader.get_linking_prompt()

        assert all(isinstance(p, str) and len(p) > 0 for p in [
            interviewer_prompt, segmentation_prompt, linking_prompt
        ])

    def test_all_required_prompts_exist(self):
        """Test that all required prompt files exist."""
        loader = PromptLoader()
        prompts_dir = loader.prompts_dir

        required_prompts = ["interviewer.txt", "segmentation.txt", "linking.txt"]

        for prompt_file in required_prompts:
            prompt_path = prompts_dir / prompt_file
            assert prompt_path.exists(), f"Required prompt file missing: {prompt_file}"

            # Check that the file has content
            content = prompt_path.read_text()
            assert len(content.strip()) > 0, f"Prompt file is empty: {prompt_file}"
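The PromptLoader implementation itself is outside this diff. A minimal sketch satisfying these tests, assuming the prompts live in a prompts/ directory next to the loader module and that caching returns the same string object on repeated loads, might look like:

    from pathlib import Path

    class PromptLoader:
        """Loads prompt templates from .txt files and caches them."""

        def __init__(self):
            # Assumption: prompts/ sits alongside this module
            self.prompts_dir = Path(__file__).parent / "prompts"
            self._cache = {}

        def load_prompt(self, name: str) -> str:
            if name not in self._cache:
                path = self.prompts_dir / f"{name}.txt"
                if not path.exists():
                    raise FileNotFoundError(f"Prompt file not found: {path}")
                self._cache[name] = path.read_text()
            return self._cache[name]  # cached str satisfies the `is` identity check

        def get_interviewer_prompt(self) -> str:
            return self.load_prompt("interviewer")

        def get_segmentation_prompt(self) -> str:
            return self.load_prompt("segmentation")

        def get_linking_prompt(self) -> str:
            return self.load_prompt("linking")

    prompt_loader = PromptLoader()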
tests/test_synthesizer_agent.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""Test synthesizer agent functionality - simplified version."""
import pytest
import uuid
from unittest.mock import MagicMock, patch
from datetime import datetime

from app.services.synthesizer import SynthesizerAgent
from app.services.vector import VectorService
from app.data.models import Note
from langchain_core.documents import Document


class TestSynthesizerAgent:
    """Test SynthesizerAgent functionality with simplified mocking."""

    @pytest.fixture
    def mock_vector_service(self):
        """Create a mock vector service."""
        return MagicMock(spec=VectorService)

    @pytest.fixture
    def sample_transcript(self):
        """Create a sample conversation transcript."""
        return [
            {"role": "user", "content": "I want to explore AI ethics"},
            {"role": "assistant", "content": "What aspects of AI ethics concern you most?"},
            {"role": "user", "content": "I'm worried about bias in algorithms"}
        ]

    @pytest.fixture
    def sample_note(self):
        """Create a sample note for linking."""
        return Note(
            id=uuid.uuid4(),
            title="Algorithmic Bias",
            content="AI systems can perpetuate discrimination when trained on biased datasets.",
            tags=["ai", "bias", "ethics"],
            session_id=uuid.uuid4(),
            created_at=datetime.utcnow()
        )

    @pytest.fixture
    def sample_neighbors(self):
        """Create sample neighboring documents."""
        return [
            Document(
                page_content="Machine learning fairness requires careful consideration.",
                metadata={"title": "ML Fairness", "tags": "ml,fairness", "similarity_score": 0.8}
            ),
            Document(
                page_content="Privacy-preserving AI techniques protect user data.",
                metadata={"title": "Privacy in AI", "tags": "privacy,ai", "similarity_score": 0.6}
            )
        ]

    def test_synthesizer_initialization(self, mock_vector_service):
        """Test that SynthesizerAgent initializes correctly."""
        with patch('app.services.synthesizer.ChatGoogleGenerativeAI') as mock_llm:
            agent = SynthesizerAgent(mock_vector_service)
            assert agent.vector_service is mock_vector_service
            assert agent.llm is not None
            # Verify it uses the Pro model with a lower temperature
            mock_llm.assert_called_once()
            call_args = mock_llm.call_args
            assert call_args.kwargs['temperature'] == 0.3

    def test_format_transcript_for_segmentation(self, mock_vector_service, sample_transcript):
        """Test transcript formatting for segmentation."""
        with patch('app.services.synthesizer.ChatGoogleGenerativeAI'):
            agent = SynthesizerAgent(mock_vector_service)

            result = agent._format_transcript_for_segmentation(sample_transcript)

            assert "User:" in result
            assert "AI Interviewer:" in result
            assert "AI ethics" in result
            assert "bias" in result.lower()

    def test_format_transcript_for_segmentation_empty(self, mock_vector_service):
        """Test transcript formatting with an empty transcript."""
        with patch('app.services.synthesizer.ChatGoogleGenerativeAI'):
            agent = SynthesizerAgent(mock_vector_service)

            result = agent._format_transcript_for_segmentation([])
            assert "Empty conversation transcript" in result

    def test_format_neighbors(self, mock_vector_service, sample_neighbors):
        """Test neighbor formatting for linking."""
        with patch('app.services.synthesizer.ChatGoogleGenerativeAI'):
            agent = SynthesizerAgent(mock_vector_service)

            result = agent._format_neighbors(sample_neighbors)

            assert "Note 1:" in result
            assert "Note 2:" in result
            assert "ML Fairness" in result
            assert "Privacy in AI" in result
            assert "Similarity Score:" in result
            assert "0.800" in result  # Formatted similarity score

    def test_format_neighbors_empty(self, mock_vector_service):
        """Test neighbor formatting with no neighbors."""
        with patch('app.services.synthesizer.ChatGoogleGenerativeAI'):
            agent = SynthesizerAgent(mock_vector_service)

            result = agent._format_neighbors([])
            assert "No existing notes found" in result

    def test_format_neighbors_with_missing_metadata(self, mock_vector_service):
        """Test that neighbor formatting handles missing metadata gracefully."""
        with patch('app.services.synthesizer.ChatGoogleGenerativeAI'):
            agent = SynthesizerAgent(mock_vector_service)

            neighbors = [
                Document(
                    page_content="Content without much metadata",
                    metadata={}  # Missing title, tags, similarity_score
                )
            ]

            result = agent._format_neighbors(neighbors)
            assert "Note 1:" in result
            assert "Content without much metadata" in result
            assert "Similarity Score: 0.000" in result  # Default value

    def test_transcript_formatting_preserves_conversation_flow(self, mock_vector_service):
        """Test that transcript formatting preserves the conversation flow."""
        with patch('app.services.synthesizer.ChatGoogleGenerativeAI'):
            agent = SynthesizerAgent(mock_vector_service)

            transcript = [
                {"role": "user", "content": "First message"},
                {"role": "assistant", "content": "First response"},
                {"role": "user", "content": "Second message"},
                {"role": "assistant", "content": "Second response"}
            ]

            result = agent._format_transcript_for_segmentation(transcript)

            # Check that order is preserved
            lines = result.split('\n\n')
            assert len(lines) == 4
            assert "User: First message" in lines[0]
            assert "AI Interviewer: First response" in lines[1]
            assert "User: Second message" in lines[2]
            assert "AI Interviewer: Second response" in lines[3]

    def test_agent_uses_correct_model_config(self, mock_vector_service):
        """Test that the agent is configured with the correct model settings."""
        with patch('app.services.synthesizer.ChatGoogleGenerativeAI') as mock_llm_class:
            with patch('app.services.synthesizer.settings') as mock_settings:
                mock_settings.LLM_PRO_MODEL = "test-pro-model"
                mock_settings.GOOGLE_API_KEY = "test-key"

                agent = SynthesizerAgent(mock_vector_service)

                # Verify the LLM was initialized with the correct parameters
                mock_llm_class.assert_called_once_with(
                    model="test-pro-model",
                    google_api_key="test-key",
                    temperature=0.3
                )
tests/test_vector_service.py (new file, 121 lines)
@@ -0,0 +1,121 @@
"""Test vector service functionality."""
import pytest
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

from app.services.vector import VectorService
from app.data.models import Note


class TestVectorService:
    """Test VectorService operations."""

    @pytest.fixture
    def sample_notes(self):
        """Create sample notes for testing."""
        session_id = uuid.uuid4()
        return [
            Note(
                title="AI Ethics",
                content="Artificial intelligence systems must be designed with ethical considerations in mind.",
                tags=["ai", "ethics", "technology"],
                session_id=session_id
            ),
            Note(
                title="Machine Learning Bias",
                content="Bias in machine learning models can perpetuate unfair discrimination.",
                tags=["ml", "bias", "fairness"],
                session_id=session_id
            )
        ]

    @pytest.mark.asyncio
    async def test_vector_service_initialization(self):
        """Test that VectorService initializes correctly."""
        with patch('app.services.vector.GoogleGenerativeAIEmbeddings'):
            with patch('app.services.vector.chromadb.PersistentClient'):
                service = VectorService()
                assert service is not None

    @pytest.mark.asyncio
    async def test_add_notes_empty_list(self):
        """Test adding an empty list of notes."""
        with patch('app.services.vector.GoogleGenerativeAIEmbeddings'):
            with patch('app.services.vector.chromadb.PersistentClient'):
                service = VectorService()
                await service.add_notes([])  # Should not raise an error

    @pytest.mark.asyncio
    async def test_add_notes_with_mocked_embeddings(self, sample_notes):
        """Test adding notes with a mocked embedding service."""
        mock_embeddings = AsyncMock()
        mock_embeddings.aembed_documents = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

        mock_collection = MagicMock()
        mock_client = MagicMock()
        mock_client.get_or_create_collection.return_value = mock_collection

        with patch('app.services.vector.GoogleGenerativeAIEmbeddings', return_value=mock_embeddings):
            with patch('app.services.vector.chromadb.PersistentClient', return_value=mock_client):
                service = VectorService()
                await service.add_notes(sample_notes)

                # Verify embeddings were requested
                mock_embeddings.aembed_documents.assert_called_once()
                # Verify collection.add was called
                mock_collection.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_semantic_search_empty_collection(self):
        """Test semantic search on an empty collection."""
        mock_collection = MagicMock()
        mock_collection.count.return_value = 0

        mock_client = MagicMock()
        mock_client.get_or_create_collection.return_value = mock_collection

        with patch('app.services.vector.GoogleGenerativeAIEmbeddings'):
            with patch('app.services.vector.chromadb.PersistentClient', return_value=mock_client):
                service = VectorService()
                results = await service.semantic_search("test query")
                assert results == []

    @pytest.mark.asyncio
    async def test_semantic_search_with_results(self):
        """Test semantic search with mocked results."""
        mock_embeddings = AsyncMock()
        mock_embeddings.aembed_query = AsyncMock(return_value=[0.1, 0.2, 0.3])

        mock_collection = MagicMock()
        mock_collection.count.return_value = 2
        mock_collection.query.return_value = {
            "documents": [["AI ethics content", "ML bias content"]],
            "metadatas": [[{"title": "AI Ethics", "tags": "ai,ethics"}, {"title": "ML Bias", "tags": "ml,bias"}]],
            "distances": [[0.2, 0.4]]
        }

        mock_client = MagicMock()
        mock_client.get_or_create_collection.return_value = mock_collection

        with patch('app.services.vector.GoogleGenerativeAIEmbeddings', return_value=mock_embeddings):
            with patch('app.services.vector.chromadb.PersistentClient', return_value=mock_client):
                service = VectorService()
                results = await service.semantic_search("ethics")

                assert len(results) == 2
                assert results[0].page_content == "AI ethics content"
                assert results[0].metadata["title"] == "AI Ethics"
                assert results[0].metadata["similarity_score"] == 0.8  # 1 - 0.2

    def test_reset_collection(self):
        """Test collection reset functionality."""
        mock_collection = MagicMock()
        mock_client = MagicMock()
        mock_client.get_or_create_collection.return_value = mock_collection

        with patch('app.services.vector.GoogleGenerativeAIEmbeddings'):
            with patch('app.services.vector.chromadb.PersistentClient', return_value=mock_client):
                service = VectorService()
                service.reset_collection()  # Should not raise an error
                mock_client.delete_collection.assert_called_once_with("skytalk_notes")
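The `similarity_score == 0.8  # 1 - 0.2` assertion pins down the distance-to-similarity conversion inside VectorService.semantic_search, whose implementation is not part of this diff. A sketch of the conversion the tests assume, with the helper name chosen here for illustration:

    from typing import List
    from langchain_core.documents import Document

    def to_documents(raw: dict) -> List[Document]:
        """Convert a ChromaDB query result into LangChain Documents."""
        documents = []
        for content, metadata, distance in zip(
            raw["documents"][0], raw["metadatas"][0], raw["distances"][0]
        ):
            metadata = dict(metadata)
            # Cosine distance -> similarity: a distance of 0.2 becomes 0.8
            metadata["similarity_score"] = 1 - distance
            documents.append(Document(page_content=content, metadata=metadata))
        return documents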