feat: implement AI agent layer with LangChain integration
- Create InterviewerAgent with Socratic questioning and RAG context
- Build SynthesizerAgent for transcript segmentation and link generation
- Integrate Google Gemini models (Flash for interviewing, Pro for synthesis)
- Add structured output parsing for Zettel extraction and linking
- Implement session termination detection with [END_SESSION] token
- Add conversation context formatting and similarity-based neighbor filtering
- Add vector service tests with mocked ChromaDB and embeddings
- Test interviewer agent RAG conversations and session termination
- Test synthesizer agent transcript formatting and neighbor analysis
- Add prompt loader tests for the external prompt system
- Test all repository CRUD operations and database transactions
This commit is contained in:
200
tests/test_interviewer_agent.py
Normal file
200
tests/test_interviewer_agent.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Test interviewer agent functionality."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from app.services.interviewer import InterviewerAgent
|
||||
from app.services.vector import VectorService
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
class TestInterviewerAgent:
    """Test InterviewerAgent functionality.

    All tests construct the agent through ``_build_agent``, which stubs out
    the Gemini LLM class and the external prompt loader so no network or
    prompt files are touched.
    """

    @staticmethod
    def _build_agent(vector_service, prompt="Test prompt {retrieved_context}"):
        """Construct an InterviewerAgent with its LLM and prompt loader patched out.

        The patches are active only while ``InterviewerAgent.__init__`` runs;
        the returned agent holds the mocked LLM instance, and tests that
        exercise generation replace ``agent.chain`` with their own mock.

        Args:
            vector_service: Vector service (real mock) injected into the agent.
            prompt: Prompt template returned by the patched prompt loader.

        Returns:
            A fully constructed InterviewerAgent wired to mocks.
        """
        with patch('app.services.interviewer.ChatGoogleGenerativeAI') as mock_llm_class:
            with patch('app.services.interviewer.prompt_loader') as mock_prompt_loader:
                mock_prompt_loader.get_interviewer_prompt.return_value = prompt
                mock_llm_class.return_value = MagicMock()
                return InterviewerAgent(vector_service)

    @pytest.fixture
    def mock_vector_service(self):
        """Create a mock vector service whose semantic_search returns one document."""
        mock_service = AsyncMock(spec=VectorService)
        mock_service.semantic_search = AsyncMock(return_value=[
            Document(
                page_content="AI ethics is important for responsible development.",
                metadata={"title": "AI Ethics", "similarity_score": 0.8}
            )
        ])
        return mock_service

    @pytest.fixture
    def mock_llm_response(self):
        """Canned LLM response used by the generation tests."""
        return "What specific aspects of AI ethics are you most concerned about?"

    @pytest.mark.asyncio
    async def test_interviewer_initialization(self, mock_vector_service):
        """Test that InterviewerAgent initializes correctly."""
        agent = self._build_agent(mock_vector_service)

        assert agent.vector_service is mock_vector_service
        assert agent.llm is not None
        assert agent.chain is not None

    @pytest.mark.asyncio
    async def test_generate_response_basic(self, mock_vector_service, mock_llm_response):
        """Test basic response generation."""
        agent = self._build_agent(mock_vector_service)

        # Replace the real chain with a mock that yields the canned response.
        agent.chain = AsyncMock()
        agent.chain.ainvoke = AsyncMock(return_value=mock_llm_response)

        transcript = [{"role": "user", "content": "I want to discuss AI ethics"}]
        response, should_end = await agent.generate_response(transcript, "AI ethics")

        assert response == mock_llm_response
        assert should_end is False
        # RAG lookup must have been performed against the vector store.
        assert mock_vector_service.semantic_search.called

    @pytest.mark.asyncio
    async def test_generate_response_with_end_session(self, mock_vector_service):
        """Test response generation with session termination."""
        agent = self._build_agent(mock_vector_service)

        agent.chain = AsyncMock()
        agent.chain.ainvoke = AsyncMock(
            return_value="Great insights! [END_SESSION] Thank you for the conversation."
        )

        transcript = [{"role": "user", "content": "I think we've covered everything"}]
        response, should_end = await agent.generate_response(transcript, "conclusion")

        # The termination token must be detected and stripped from the reply,
        # while the surrounding text is preserved.
        assert should_end is True
        assert "[END_SESSION]" not in response
        assert "Great insights!" in response
        assert "Thank you" in response

    @pytest.mark.asyncio
    async def test_generate_response_empty_transcript(self, mock_vector_service, mock_llm_response):
        """Test response generation with empty transcript."""
        agent = self._build_agent(mock_vector_service)

        agent.chain = AsyncMock()
        agent.chain.ainvoke = AsyncMock(return_value=mock_llm_response)

        response, should_end = await agent.generate_response([], "test query")

        assert response == mock_llm_response
        assert should_end is False

    def test_format_context_with_documents(self):
        """Test context formatting with documents."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        docs = [
            Document(
                page_content="Long content " * 50,  # Will be truncated
                metadata={"title": "Test Doc 1"}
            ),
            Document(
                page_content="Short content",
                metadata={"title": "Test Doc 2"}
            )
        ]

        result = agent._format_context(docs)
        assert "Test Doc 1" in result
        assert "Test Doc 2" in result
        assert "..." in result  # Truncation indicator

    def test_format_context_empty(self):
        """Test context formatting with no documents."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        result = agent._format_context([])
        assert "No relevant context" in result

    def test_format_transcript(self):
        """Test transcript formatting."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        transcript = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "Let's discuss AI"}
        ]

        result = agent._format_transcript(transcript)
        assert "User: Hello" in result
        assert "Assistant: Hi there!" in result
        assert "User: Let's discuss AI" in result

    def test_format_transcript_empty(self):
        """Test transcript formatting with empty list."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        result = agent._format_transcript([])
        assert "start of the conversation" in result

    def test_format_transcript_long_conversation(self):
        """Test transcript formatting limits to last 6 messages."""
        agent = self._build_agent(MagicMock(), prompt="Test prompt")

        # Create 10 messages; only the trailing 6 should survive formatting.
        transcript = [
            {"role": "user", "content": f"Message {i}"} for i in range(10)
        ]

        result = agent._format_transcript(transcript)

        # Should only contain the last 6 messages (indices 4..9).
        assert "Message 4" in result  # Should be included (10-6=4)
        assert "Message 9" in result  # Should be included (last message)
        assert "Message 0" not in result  # Should be excluded
        assert "Message 3" not in result  # Should be excluded
|
||||
Reference in New Issue
Block a user