LangChain Prompt Chaining
This notebook demonstrates how to chain prompts effectively with LangChain to create sophisticated AI workflows that handle complex, multi-step reasoning tasks.
What is Prompt Chaining?
Prompt chaining is a technique in which you connect multiple prompts so that the output of one prompt becomes the input to the next (a minimal sketch follows the list below). This allows you to:
- Break down complex tasks into manageable steps
- Improve reasoning quality through step-by-step processing
- Create specialized prompts for different parts of a workflow
- Build more reliable and predictable AI systems
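A minimal sketch of the idea, using the same LCEL pipe syntax and ChatOpenAI model the notebook uses later (the two prompts here are purely illustrative): the string produced by the first prompt-plus-model step is fed directly into the second prompt.

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")

# Prompt 1: produce an intermediate result
summarize = ChatPromptTemplate.from_template("Summarize in one sentence: {text}")
# Prompt 2: consume the intermediate result
title = ChatPromptTemplate.from_template("Write a catchy title for this summary: {summary}")

# The summary from the first chain becomes the {summary} input of the second
chain = (
    {"summary": summarize | llm | StrOutputParser()}
    | title
    | llm
    | StrOutputParser()
)
print(chain.invoke({"text": "LangChain lets you compose prompts into multi-step workflows."}))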
Key Concepts
1. Sequential Chains
Execute prompts in a specific order where each step builds on the previous one.
2. Conditional Chains
Route to different prompts based on the content or results of previous steps.
3. Parallel Chains
Execute multiple prompts simultaneously and combine their results.
4. Memory Integration
Maintain context and state across multiple prompt executions.
Common Use Cases
- Content Creation: Research → Outline → Writing → Editing
- Data Analysis: Collection → Processing → Analysis → Reporting
- Decision Making: Information Gathering → Analysis → Recommendation → Action Plan
- Code Generation: Requirements → Design → Implementation → Testing
Implementation Patterns
The notebook demonstrates several practical implementations:
- Simple Sequential Chain: Basic prompt chaining for content creation
- Router Chain: Conditional routing based on input classification
- Map-Reduce Chain: Processing multiple inputs and combining results (not spelled out in the cells below; see the sketch after this list)
- Conversation Chain: Maintaining context across multiple interactions
- Custom Chain: Building specialized chains for specific use cases
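The map-reduce pattern is the one not spelled out in the notebook cells further down, so here is a minimal sketch under the same setup (the model choice, prompt wording, and the docs list are illustrative): map a summary prompt over each input with .batch(), then reduce the partial results with a combining prompt.

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# Map step: summarize each document independently
map_prompt = ChatPromptTemplate.from_template("Summarize this document in 2-3 sentences:\n{doc}")
map_chain = map_prompt | llm | StrOutputParser()

# Reduce step: combine the partial summaries into one answer
reduce_prompt = ChatPromptTemplate.from_template(
    "Combine these summaries into a single coherent overview:\n{summaries}"
)
reduce_chain = reduce_prompt | llm | StrOutputParser()

docs = ["First document text...", "Second document text...", "Third document text..."]
partial_summaries = map_chain.batch([{"doc": d} for d in docs])  # map: runs concurrently
overview = reduce_chain.invoke({"summaries": "\n\n".join(partial_summaries)})  # reduce
print(overview)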
Best Practices
- Clear Interfaces: Define clear input/output formats between chain steps
- Error Handling: Implement robust error handling and fallback mechanisms
- Validation: Validate outputs at each step before passing to the next
- Monitoring: Track performance and quality metrics across the chain (a lightweight sketch follows this list)
- Modularity: Design reusable chain components
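For the monitoring point, one lightweight option is to wrap individual steps in a timing and logging RunnableLambda. The timed helper below is a hypothetical sketch, not part of the notebook; LangSmith tracing or callback handlers would be the fuller solution.

import time
from langchain_core.runnables import RunnableLambda

def timed(step_name, runnable):
    """Wrap a chain step so its latency and output size are printed on each call."""
    def _run(inputs):
        start = time.perf_counter()
        output = runnable.invoke(inputs)
        elapsed = time.perf_counter() - start
        print(f"[{step_name}] {elapsed:.2f}s, {len(str(output))} chars")
        return output
    return RunnableLambda(_run)

# Usage (illustrative): wrap any stage before composing it into a larger chain
# monitored_research = timed("research", research_prompt | llm | StrOutputParser())
# monitored_research.invoke({"topic": "Prompt chaining"})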
Advanced Techniques
- Dynamic Routing: Automatically determine the best chain path
- Feedback Loops: Incorporate self-correction mechanisms (sketched after this list)
- Parallel Processing: Optimize performance with concurrent execution
- Chain Composition: Combine multiple specialized chains
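As one possible shape for the feedback-loop idea, the sketch below runs a generate → critique → revise pass; the prompts and the single-pass structure are illustrative rather than taken from the notebook.

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
parser = StrOutputParser()

draft_prompt = ChatPromptTemplate.from_template("Write a short answer to: {question}")
critique_prompt = ChatPromptTemplate.from_template(
    "Critique this answer for accuracy and clarity:\n{draft}"
)
revise_prompt = ChatPromptTemplate.from_template(
    "Question: {question}\nDraft answer:\n{draft}\nCritique:\n{critique}\n\n"
    "Rewrite the answer, fixing the issues raised."
)

def self_correct(question: str) -> str:
    """One generate -> critique -> revise pass; loop it for stricter quality bars."""
    draft = (draft_prompt | llm | parser).invoke({"question": question})
    critique = (critique_prompt | llm | parser).invoke({"draft": draft})
    return (revise_prompt | llm | parser).invoke(
        {"question": question, "draft": draft, "critique": critique}
    )

print(self_correct("Why does prompt chaining improve reliability?"))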
This approach enables building sophisticated AI applications that can handle complex workflows while maintaining reliability and predictability.
Notebook Information
# Install required packages
!pip install langchain langchain-openai langchain-core langchain-community python-dotenv

import os
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableBranch
from langchain_openai import ChatOpenAI
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import BaseMessage, HumanMessage, AIMessage
# Load environment variables
load_dotenv()
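# ChatOpenAI below assumes OPENAI_API_KEY is set in the .env file or the shell environment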
# Initialize the LLM
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)

# Step 1: Research prompt
research_prompt = ChatPromptTemplate.from_template(
"You are a research assistant. Research the topic: {topic}\n"
"Provide 3-5 key points about this topic that would be useful for writing an article.\n\n"
"Research findings:"
)
# Step 2: Article writing prompt
writing_prompt = ChatPromptTemplate.from_template(
"Based on the following research findings, write a well-structured article about {topic}.\n\n"
"Research Findings:\n{research_findings}\n\n"
"Article:"
)
# Create the chain using LCEL
research_chain = research_prompt | llm | StrOutputParser()
# Sequential chain that passes research findings to writing prompt
sequential_chain = (
{"topic": RunnablePassthrough(), "research_findings": research_chain}
| writing_prompt
| llm
| StrOutputParser()
)
# Test the sequential chain
topic = "Artificial Intelligence in Healthcare"
result = sequential_chain.invoke({"topic": topic})
print(f"Article about {topic}:")
print(result)

# Classification prompt to determine question type
classification_prompt = ChatPromptTemplate.from_template(
"Classify the following question into one of these categories: physics, math, history, or general.\n"
"Question: {question}\n"
"Category (respond with just the category name):"
)
# Specialized prompts for different subjects
physics_prompt = ChatPromptTemplate.from_template(
"You are a physics expert. Answer this physics question in detail:\n{question}"
)
math_prompt = ChatPromptTemplate.from_template(
"You are a mathematics expert. Solve this math problem step by step:\n{question}"
)
history_prompt = ChatPromptTemplate.from_template(
"You are a history expert. Provide a comprehensive answer to this history question:\n{question}"
)
general_prompt = ChatPromptTemplate.from_template(
"Answer the following question:\n{question}"
)
# Create classifier
classifier = classification_prompt | llm | StrOutputParser()
# Create routing function
def route_question(info):
category = info["category"].lower().strip()
question = info["question"]
if "physics" in category:
return physics_prompt.format(question=question)
elif "math" in category:
return math_prompt.format(question=question)
elif "history" in category:
return history_prompt.format(question=question)
else:
return general_prompt.format(question=question)
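# Aside: the same conditional routing could also be expressed declaratively with
# RunnableBranch (imported above but otherwise unused). Sketch only; the specialized
# prompts simply ignore the extra "category" key. Not used further below.
branching_chain = (
    {"question": lambda x: x["question"], "category": classifier}
    | RunnableBranch(
        (lambda x: "physics" in x["category"].lower(), physics_prompt),
        (lambda x: "math" in x["category"].lower(), math_prompt),
        (lambda x: "history" in x["category"].lower(), history_prompt),
        general_prompt,
    )
    | llm
    | StrOutputParser()
)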
# Create the routing chain
routing_chain = (
{
"question": RunnablePassthrough(),
"category": classifier
}
| RunnableLambda(route_question)
| llm
| StrOutputParser()
)

# Test the routing chain with different types of questions
questions = [
"What is Newton's second law of motion?",
"Solve for x: 2x + 5 = 15",
"When did World War II end?",
"What's the best programming language to learn?"
]
for question in questions:
print(f"Question: {question}")
result = routing_chain.invoke({"question": question})
print(f"Answer: {result}")
print("-" * 50) from langchain.memory import ConversationBufferMemory
from langchain_core.runnables import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
# Store for session histories
store = {}
def get_session_history(session_id: str) -> BaseChatMessageHistory:
if session_id not in store:
store[session_id] = ChatMessageHistory()
return store[session_id]
# Create conversation prompt
conversation_prompt = ChatPromptTemplate.from_messages([
("system", "You are a helpful AI assistant. Have a natural conversation with the user."),
("placeholder", "{chat_history}"),
("human", "{input}")
])
# Create the conversation chain
conversation_chain = conversation_prompt | llm | StrOutputParser()
# Add message history
conversation_with_history = RunnableWithMessageHistory(
conversation_chain,
get_session_history,
input_messages_key="input",
history_messages_key="chat_history",
)

# Test conversation with context
session_id = "user_123"
# First interaction
response1 = conversation_with_history.invoke(
{"input": "My name is Alice and I'm learning Python."},
config={"configurable": {"session_id": session_id}}
)
print("Human: My name is Alice and I'm learning Python.")
print(f"AI: {response1}")
print("-" * 50)
# Second interaction - should remember the name
response2 = conversation_with_history.invoke(
{"input": "What's a good first project for me?"},
config={"configurable": {"session_id": session_id}}
)
print("Human: What's a good first project for me?")
print(f"AI: {response2}")
print("-" * 50)
# Third interaction - should remember both name and context
response3 = conversation_with_history.invoke(
{"input": "How long might that take me?"},
config={"configurable": {"session_id": session_id}}
)
print("Human: How long might that take me?")
print(f"AI: {response3}") # Step 1: Code analysis prompt
analysis_prompt = ChatPromptTemplate.from_template(
"Analyze the following code and identify:\n"
"1. Potential bugs or errors\n"
"2. Code quality issues\n"
"3. Performance concerns\n"
"4. Security vulnerabilities\n\n"
"Code:\n{code}\n\n"
"Analysis:"
)
# Step 2: Improvement suggestions prompt
improvement_prompt = ChatPromptTemplate.from_template(
"Based on the code analysis, provide specific improvement suggestions:\n\n"
"Original Code:\n{code}\n\n"
"Analysis:\n{analysis}\n\n"
"Improvement Suggestions:"
)
# Step 3: Refactored code prompt
refactor_prompt = ChatPromptTemplate.from_template(
"Refactor the following code based on the analysis and suggestions:\n\n"
"Original Code:\n{code}\n\n"
"Analysis:\n{analysis}\n\n"
"Suggestions:\n{suggestions}\n\n"
"Refactored Code:"
)
# Create individual chains
analysis_chain = analysis_prompt | llm | StrOutputParser()
improvement_chain = improvement_prompt | llm | StrOutputParser()
refactor_chain = refactor_prompt | llm | StrOutputParser()
# Create the complete code review chain
code_review_chain = (
{
"code": RunnablePassthrough(),
"analysis": analysis_chain
}
| {
"code": lambda x: x["code"],
"analysis": lambda x: x["analysis"],
"suggestions": improvement_chain
}
| {
"analysis": lambda x: x["analysis"],
"suggestions": lambda x: x["suggestions"],
"refactored_code": refactor_chain
}
)

# Test with sample code
sample_code = """
def calculate_average(numbers):
total = 0
for i in range(len(numbers)):
total = total + numbers[i]
average = total / len(numbers)
return average
nums = [1, 2, 3, 4, 5]
result = calculate_average(nums)
print("Average is:", result)
"""
review_result = code_review_chain.invoke({"code": sample_code})
print("ANALYSIS:")
print(review_result["analysis"])
print("\n" + "="*50 + "\n")
print("SUGGESTIONS:")
print(review_result["suggestions"])
print("\n" + "="*50 + "\n")
print("REFACTORED CODE:")
print(review_result["refactored_code"]) from langchain_core.runnables import RunnableParallel
# Different analysis prompts that can run in parallel
sentiment_prompt = ChatPromptTemplate.from_template(
"Analyze the sentiment of this text (positive, negative, or neutral):\n{text}"
)
summary_prompt = ChatPromptTemplate.from_template(
"Provide a brief summary of this text:\n{text}"
)
keywords_prompt = ChatPromptTemplate.from_template(
"Extract the main keywords from this text:\n{text}"
)
# Create individual chains
sentiment_chain = sentiment_prompt | llm | StrOutputParser()
summary_chain = summary_prompt | llm | StrOutputParser()
keywords_chain = keywords_prompt | llm | StrOutputParser()
# Create parallel chain
parallel_analysis = RunnableParallel(
sentiment=sentiment_chain,
summary=summary_chain,
keywords=keywords_chain
)
# Test with sample text
sample_text = """
Artificial Intelligence is revolutionizing healthcare by enabling more accurate diagnoses,
personalized treatment plans, and efficient drug discovery. Machine learning algorithms
can analyze medical images with unprecedented precision, while natural language processing
helps extract insights from clinical notes. However, challenges remain in ensuring data
privacy, algorithmic fairness, and regulatory compliance.
"""
results = parallel_analysis.invoke({"text": sample_text})
print("SENTIMENT ANALYSIS:")
print(results["sentiment"])
print("\nSUMMARY:")
print(results["summary"])
print("\nKEYWORDS:")
print(results["keywords"]) from langchain_core.runnables import RunnableLambda
from langchain_core.exceptions import OutputParserException
def safe_invoke(chain, inputs, max_retries=3):
"""
Safely invoke a chain with error handling and retries.
"""
for attempt in range(max_retries):
try:
result = chain.invoke(inputs)
return {"success": True, "result": result}
except OutputParserException as e:
print(f"Attempt {attempt + 1} failed with parser error: {e}")
if attempt == max_retries - 1:
return {"success": False, "error": f"Parser error after {max_retries} attempts: {e}"}
except Exception as e:
print(f"Attempt {attempt + 1} failed with error: {e}")
if attempt == max_retries - 1:
return {"success": False, "error": f"Failed after {max_retries} attempts: {e}"}
return {"success": False, "error": "Unexpected error occurred"}
# Create a validation function
def validate_output(output):
if len(output.strip()) < 10:
raise ValueError("Output too short, likely incomplete")
return output
# Chain with validation
validated_chain = (
research_prompt
| llm
| StrOutputParser()
| RunnableLambda(validate_output)
)
# Test error handling
result = safe_invoke(validated_chain, {"topic": "Machine Learning"})
if result["success"]:
print("Chain executed successfully!")
print(f"Result length: {len(result['result'])} characters")
else:
print(f"Chain failed: {result['error']}")