- Create InterviewerAgent with Socratic questioning and RAG context - Build SynthesizerAgent for transcript segmentation and link generation - Integrate Google Gemini models (Flash for interviewing, Pro for synthesis) - Add structured output parsing for Zettel extraction and linking - Implement session termination detection with [END_SESSION] token - Add conversation context formatting and similarity-based neighbor filtering - Add vector service tests with mocked ChromaDB and embeddings - Test interviewer agent RAG conversations and session termination - Test synthesizer agent transcript formatting and neighbor analysis - Add prompt loader tests for external prompt system - Test all repository CRUD operations and database transactions
200 lines
9.4 KiB
Python
200 lines
9.4 KiB
Python
"""Test interviewer agent functionality."""
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
|
|
from app.services.interviewer import InterviewerAgent
|
|
from app.services.vector import VectorService
|
|
from langchain_core.documents import Document
|
|
|
|
|
|
class TestInterviewerAgent:
    """Tests for InterviewerAgent: construction, RAG-backed response
    generation, [END_SESSION] termination handling, and the private
    context/transcript formatting helpers.
    """

    @staticmethod
    def _build_agent(vector_service, prompt="Test prompt {retrieved_context}"):
        """Construct an InterviewerAgent with its externals patched out.

        Patches the Gemini chat model class and the prompt loader so no
        network access or prompt files are needed.  The patches only have
        to be active during ``__init__`` (that is when the agent builds
        its LLM and chain), so the agent can safely be returned after the
        contexts exit.

        Args:
            vector_service: Mock vector service injected into the agent.
            prompt: Prompt template returned by the patched loader.

        Returns:
            A fully constructed InterviewerAgent wired to mocks.
        """
        with patch('app.services.interviewer.ChatGoogleGenerativeAI'), \
             patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
            mock_prompt_loader.get_interviewer_prompt.return_value = prompt
            return InterviewerAgent(vector_service)

    @pytest.fixture
    def mock_vector_service(self):
        """Async VectorService mock whose semantic_search yields one document."""
        mock_service = AsyncMock(spec=VectorService)
        mock_service.semantic_search = AsyncMock(return_value=[
            Document(
                page_content="AI ethics is important for responsible development.",
                metadata={"title": "AI Ethics", "similarity_score": 0.8},
            )
        ])
        return mock_service

    @pytest.fixture
    def mock_llm_response(self):
        """Canned LLM reply used by the response-generation tests."""
        return "What specific aspects of AI ethics are you most concerned about?"

    @pytest.mark.asyncio
    async def test_interviewer_initialization(self, mock_vector_service):
        """Agent wires up the vector service, LLM, and chain on construction."""
        agent = self._build_agent(mock_vector_service)

        assert agent.vector_service is mock_vector_service
        assert agent.llm is not None
        assert agent.chain is not None

    @pytest.mark.asyncio
    async def test_generate_response_basic(self, mock_vector_service, mock_llm_response):
        """A normal reply is returned verbatim and does not end the session."""
        agent = self._build_agent(mock_vector_service)
        agent.chain = AsyncMock()
        agent.chain.ainvoke = AsyncMock(return_value=mock_llm_response)

        transcript = [{"role": "user", "content": "I want to discuss AI ethics"}]
        response, should_end = await agent.generate_response(transcript, "AI ethics")

        assert response == mock_llm_response
        assert should_end is False
        # RAG context must have been fetched for the user's query.
        assert mock_vector_service.semantic_search.called

    @pytest.mark.asyncio
    async def test_generate_response_with_end_session(self, mock_vector_service):
        """The [END_SESSION] token sets should_end and is stripped from the reply."""
        agent = self._build_agent(mock_vector_service)
        agent.chain = AsyncMock()
        agent.chain.ainvoke = AsyncMock(
            return_value="Great insights! [END_SESSION] Thank you for the conversation."
        )

        transcript = [{"role": "user", "content": "I think we've covered everything"}]
        response, should_end = await agent.generate_response(transcript, "conclusion")

        assert should_end is True
        assert "[END_SESSION]" not in response
        # The surrounding text survives token removal.
        assert "Great insights!" in response
        assert "Thank you" in response

    @pytest.mark.asyncio
    async def test_generate_response_empty_transcript(self, mock_vector_service, mock_llm_response):
        """An empty transcript is accepted and produces a normal reply."""
        agent = self._build_agent(mock_vector_service)
        agent.chain = AsyncMock()
        agent.chain.ainvoke = AsyncMock(return_value=mock_llm_response)

        response, should_end = await agent.generate_response([], "test query")

        assert response == mock_llm_response
        assert should_end is False

    def test_format_context_with_documents(self):
        """Document titles survive formatting and long content is truncated."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        docs = [
            Document(
                page_content="Long content " * 50,  # long enough to trigger truncation
                metadata={"title": "Test Doc 1"},
            ),
            Document(
                page_content="Short content",
                metadata={"title": "Test Doc 2"},
            ),
        ]

        result = agent._format_context(docs)

        assert "Test Doc 1" in result
        assert "Test Doc 2" in result
        assert "..." in result  # truncation indicator

    def test_format_context_empty(self):
        """With no documents the formatter reports missing context."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        result = agent._format_context([])

        assert "No relevant context" in result

    def test_format_transcript(self):
        """Messages are rendered with capitalized role labels."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        transcript = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "Let's discuss AI"},
        ]

        result = agent._format_transcript(transcript)

        assert "User: Hello" in result
        assert "Assistant: Hi there!" in result
        assert "User: Let's discuss AI" in result

    def test_format_transcript_empty(self):
        """An empty transcript yields the start-of-conversation placeholder."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        result = agent._format_transcript([])

        assert "start of the conversation" in result

    def test_format_transcript_long_conversation(self):
        """Only the most recent 6 messages are kept in the formatted transcript."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        transcript = [{"role": "user", "content": f"Message {i}"} for i in range(10)]

        result = agent._format_transcript(transcript)

        # Last 6 of 10 messages are indices 4..9; earlier ones are dropped.
        assert "Message 4" in result
        assert "Message 9" in result
        assert "Message 0" not in result
        assert "Message 3" not in result