0%

LangGraph实战——构建复杂AI工作流

最近研究LangGraph,发现它就像给AI智能体装上了“大脑”,通过节点和边构建思维链条,让AI能像人一样进行复杂推理和决策…

介绍

  LangGraph是一个专门为构建复杂AI工作流和智能代理系统而设计的图框架。它将传统的状态机概念与现代大语言模型相结合,提供了强大而灵活的方式来设计和执行AI驱动的决策流程。本文将深入探讨LangGraph的核心概念、实际应用和最佳实践。

LangGraph核心概念

图结构基础

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# LangGraph图结构实现
from typing import TypedDict, List, Dict, Any, Optional
from langgraph.graph import StateGraph, END
from langchain_core.pydantic_v1 import BaseModel
import asyncio
import json

class _GraphStateExtras(TypedDict, total=False):
    """Optional state keys that individual nodes add while a graph runs.

    They are declared separately (``total=False``) so that existing code which
    builds a ``GraphState`` with only the seven core keys keeps type-checking.
    """
    # Written by ValidationNode after scoring the answer.
    needs_revision: bool
    # Written by ResponseValidationNode and read by the response router.
    validation_result: Dict[str, Any]
    # Written by EscalationHandlerNode.
    escalation_info: Dict[str, Any]


class GraphState(_GraphStateExtras):
    """Shared state that is handed from node to node.

    The keys below are required; the optional keys inherited from
    ``_GraphStateExtras`` are declared so that nodes writing them do not
    step outside the state schema.
    """
    input_query: str
    context: str
    intermediate_steps: List[Dict[str, Any]]
    final_answer: Optional[str]
    error: Optional[str]
    metadata: Dict[str, Any]
    execution_history: List[str]

class Document(BaseModel):
    """Document model with ``content``, ``source`` and ``metadata`` fields."""

    content: str
    source: str
    metadata: Dict[str, Any]

class LangGraphNode:
    """Base class for the workflow nodes defined in this file.

    Subclasses implement :meth:`execute`. Instances are also awaitable
    callables (``await node(state)``) so that they can be handed directly to
    ``StateGraph.add_node``, which expects a callable rather than an object
    with an ``execute`` method.
    """

    def __init__(self, name: str, description: str = ""):
        """Store the node's name/description and start with no links."""
        self.name = name
        self.description = description
        self.dependencies = []
        self.outputs = []

    async def execute(self, state: "GraphState") -> "GraphState":
        """Run the node's logic on ``state``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    async def __call__(self, state: "GraphState") -> "GraphState":
        """Delegate to :meth:`execute` so the instance works as a graph node."""
        return await self.execute(state)

    def add_dependency(self, node: 'LangGraphNode'):
        """Record ``node`` as a dependency of this node."""
        self.dependencies.append(node)

    def add_output(self, node: 'LangGraphNode'):
        """Record ``node`` as an output of this node."""
        self.outputs.append(node)

class ResearchNode(LangGraphNode):
    """Research node that gathers information for the incoming query."""

    def __init__(self, name: str, llm_client, search_tool):
        super().__init__(name, "收集相关信息的节点")
        self.llm_client = llm_client
        self.search_tool = search_tool

    async def execute(self, state: GraphState) -> GraphState:
        """Search for the query and fold the results into the shared context.

        Any exception is captured in ``state['error']`` and logged in
        ``execution_history``; the state is returned in both cases.
        """
        try:
            question = state['input_query']

            # Ask the search tool for material related to the question.
            found = await self.search_tool.search(question)

            # Extend whatever context already exists with the new findings.
            enriched = f"{state.get('context', '')}\n\n搜索结果:\n{found}"

            # Trace entry describing this step.
            step = dict(
                node=self.name,
                action="research",
                input=question,
                output=found,
                timestamp=asyncio.get_event_loop().time(),
            )

            state['context'] = enriched
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Completed research")
        except Exception as exc:
            state['error'] = f"Research node error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

class AnalysisNode(LangGraphNode):
    """Analysis node that asks the LLM to dissect the gathered context."""

    def __init__(self, name: str, llm_client):
        super().__init__(name, "分析收集到的信息")
        self.llm_client = llm_client

    async def execute(self, state: GraphState) -> GraphState:
        """Run the analysis prompt and append its result to the context.

        Failures are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        try:
            gathered = state['context']
            question = state['input_query']

            # Prompt asking for main points, disputes, gaps and logical links.
            prompt = f"""
请分析以下信息并提取关键要点:

问题:{question}

信息:{gathered}

请提供:
1. 主要观点
2. 争议点
3. 缺失信息
4. 逻辑关系
"""

            findings = await self.llm_client.acall(prompt)

            # The analysis is appended to the context, not substituted for it.
            state['context'] = f"{state['context']}\n\n分析结果:\n{findings}"

            step = dict(
                node=self.name,
                action="analysis",
                input=gathered,
                output=findings,
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Completed analysis")
        except Exception as exc:
            state['error'] = f"Analysis node error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

class SynthesisNode(LangGraphNode):
    """Synthesis node that turns the accumulated context into a final answer."""

    def __init__(self, name: str, llm_client):
        super().__init__(name, "综合信息生成最终答案")
        self.llm_client = llm_client

    async def execute(self, state: GraphState) -> GraphState:
        """Ask the LLM for the final answer and store it in ``final_answer``.

        Failures are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        try:
            background = state['context']
            question = state['input_query']

            # Prompt requesting answer, evidence, caveats and next steps.
            prompt = f"""
基于以下信息,请提供一个全面且准确的回答:

问题:{question}

背景信息:{background}

请提供:
1. 直接回答问题
2. 支持论据
3. 潜在限制或不确定性
4. 建议的后续步骤
"""

            answer = await self.llm_client.acall(prompt)
            state['final_answer'] = answer

            step = dict(
                node=self.name,
                action="synthesis",
                input=background,
                output=answer,
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Generated final answer")
        except Exception as exc:
            state['error'] = f"Synthesis node error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

class ValidationNode(LangGraphNode):
    """Validation node that scores the final answer with simple heuristics."""

    def __init__(self, name: str):
        super().__init__(name, "验证答案的质量和准确性")
        # Answers scoring below this value are flagged for revision.
        self.quality_threshold = 0.7

    async def execute(self, state: GraphState) -> GraphState:
        """Score ``state['final_answer']`` and record whether it needs revision.

        ``state['needs_revision']`` is written on every run (True or False),
        so a revised answer that passes clears a flag left by an earlier
        failed validation. Failures are stored in ``state['error']`` and the
        execution history instead of being raised.
        """
        try:
            # ``final_answer`` can be present but None, in which case
            # ``.get(key, '')`` would still return None.
            answer = state.get('final_answer') or ''

            quality_score = self.assess_quality(answer, state['input_query'])
            is_valid = quality_score >= self.quality_threshold

            validation_result = {
                "quality_score": quality_score,
                "is_valid": is_valid,
                "feedback": "Answer meets quality standards" if is_valid else "Answer needs improvement"
            }

            # Trace entry describing this step.
            intermediate_step = {
                "node": self.name,
                "action": "validation",
                "input": answer,
                "output": validation_result,
                "timestamp": asyncio.get_event_loop().time()
            }

            state['intermediate_steps'].append(intermediate_step)
            state['execution_history'].append(f"{self.name}: Validation complete (score: {quality_score:.2f})")

            # Set the flag both ways; leaving a stale True in place would keep
            # routing back to the revision branch forever.
            state['needs_revision'] = not is_valid

            return state

        except Exception as e:
            state['error'] = f"Validation node error: {str(e)}"
            state['execution_history'].append(f"{self.name}: Error - {str(e)}")
            return state

    def assess_quality(self, answer: str, query: str) -> float:
        """Return a heuristic quality score in ``[0.0, 1.0]`` for ``answer``.

        The score adds up: length (> 50 chars, 0.2), word overlap with the
        query (0.05 per shared word, capped at 0.3), presence of a causal
        keyword (0.2), more than one line (0.15), and "no '?' and longer
        than 20 chars" (0.15). ``None`` inputs are treated as empty strings.
        """
        answer = answer or ''
        query = query or ''
        score = 0.0

        # Length.
        if len(answer) > 50:
            score += 0.2

        # Relevance (simplified): whitespace-separated words shared with the query.
        query_words = set(query.lower().split())
        answer_words = set(answer.lower().split())
        overlap = len(query_words & answer_words)
        if overlap > 0:
            score += min(0.3, overlap * 0.05)

        # Informational richness: causal keywords.
        if any(keyword in answer for keyword in ('因为', '所以', '原因')):
            score += 0.2

        # Structure: more than one line.
        if len(answer.split('\n')) > 1:
            score += 0.15

        # Completeness: not phrased as a question and not trivially short.
        if '?' not in answer and len(answer) > 20:
            score += 0.15

        return min(score, 1.0)

条件分支与决策逻辑

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# LangGraph条件分支实现
class ConditionalRouter:
    """Router that picks the next node from an ordered list of predicates."""

    def __init__(self):
        self.conditions = []

    def add_condition(self, name: str, condition_func, target_node: str):
        """Register a predicate and the node to route to when it holds."""
        entry = dict(name=name, func=condition_func, target=target_node)
        self.conditions.append(entry)

    def route(self, state: "GraphState") -> str:
        """Return the target of the first predicate that accepts ``state``.

        Predicates are tried in registration order and evaluation stops at
        the first truthy result; ``'default'`` is returned when none match.
        """
        matching_targets = (
            rule['target'] for rule in self.conditions if rule['func'](state)
        )
        return next(matching_targets, 'default')

class ComplexityEvaluator:
    """Heuristic classifier that labels a query ``'simple'`` or ``'complex'``."""

    @staticmethod
    def evaluate_complexity(query: str) -> str:
        """Classify ``query`` by indicator phrases, then by length and jargon.

        If more complex than simple indicator phrases occur the query is
        ``'complex'``; otherwise any simple indicator makes it ``'simple'``.
        With no indicators at all, more than 10 whitespace-separated words
        or any listed technical term yields ``'complex'``.
        """
        lowered = query.lower()

        def occurrences(phrases) -> int:
            # Number of phrases that appear somewhere in the lowered query.
            return sum(1 for phrase in phrases if phrase in lowered)

        simple_hits = occurrences(('什么是', 'who is', 'when was', 'where is'))
        complex_hits = occurrences(
            ('为什么', '如何', '影响', '解决方案', 'compared to', 'analyze')
        )

        if complex_hits > simple_hits:
            return 'complex'
        if simple_hits > 0:
            return 'simple'

        # No indicator phrase matched: fall back to length and terminology.
        jargon = (
            'algorithm', 'framework', 'architecture', 'protocol', 'methodology',
            'paradigm', 'heuristic', 'optimization', 'implementation',
        )
        is_long = len(query.split()) > 10
        uses_jargon = occurrences(jargon) > 0
        return 'complex' if (is_long or uses_jargon) else 'simple'

class DynamicWorkflowBuilder:
    """Builds a workflow graph whose shape depends on query complexity.

    Args:
        llm_client: client handed to the LLM-backed nodes. Defaults to
            ``None`` for backward compatibility, but nodes built without a
            client can only record errors when they run.
        search_tool: search tool handed to the research node (same caveat).
        max_revisions: upper bound on how many times a failed validation may
            send the workflow back to the analysis step.
    """

    def __init__(self, llm_client=None, search_tool=None, max_revisions: int = 3):
        self.nodes = {}
        self.edges = []
        self.llm_client = llm_client
        self.search_tool = search_tool
        self.max_revisions = max_revisions

    def build_workflow(self, query: str) -> "StateGraph":
        """Return an uncompiled ``StateGraph`` suited to ``query``.

        Simple queries get research -> synthesis -> validation. Complex
        queries add an analysis step and a conditional edge that loops back
        to analysis while validation asks for a revision.
        """
        complexity = ComplexityEvaluator.evaluate_complexity(query)

        workflow = StateGraph(GraphState)

        # Nodes shared by both variants.
        workflow.add_node("research", ResearchNode("research", self.llm_client, self.search_tool))
        workflow.add_node("synthesis", SynthesisNode("synthesis", self.llm_client))
        workflow.add_node("validation", ValidationNode("validation"))
        workflow.set_entry_point("research")

        if complexity == 'simple':
            # Simple query: research -> synthesis -> validation.
            workflow.add_edge("research", "synthesis")
            workflow.add_edge("synthesis", "validation")
            workflow.add_edge("validation", END)
        else:
            # Complex query: research -> analysis -> synthesis -> validation,
            # optionally iterating.
            workflow.add_node("analysis", AnalysisNode("analysis", self.llm_client))

            workflow.add_edge("research", "analysis")
            workflow.add_edge("analysis", "synthesis")
            workflow.add_edge("synthesis", "validation")

            # Conditional edge: go back to analysis when a revision is needed.
            workflow.add_conditional_edges(
                "validation",
                self.validation_router,
                {
                    "revise": "analysis",
                    "complete": END
                }
            )

        return workflow

    def validation_router(self, state: "GraphState"):
        """Return ``"revise"`` or ``"complete"`` for the validation branch.

        A revision is requested only while ``needs_revision`` is set and the
        number of recorded validation steps has not exceeded
        ``max_revisions``; without the cap an answer that never passes would
        loop until the graph runtime aborts.
        """
        if not state.get('needs_revision', False):
            return "complete"

        # Each validation run appends a step with action == 'validation'.
        validations_done = sum(
            1
            for step in (state.get('intermediate_steps') or [])
            if isinstance(step, dict) and step.get('action') == 'validation'
        )
        if validations_done > self.max_revisions:
            return "complete"
        return "revise"

高级图操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# LangGraph高级功能
class GraphOptimizer:
    """Graph optimizer that applies a fixed list of rewrite strategies.

    NOTE(review): the methods below call ``graph.copy()``,
    ``graph._get_next_nodes``, ``graph._get_previous_nodes``,
    ``graph.remove_edge`` and ``graph.remove_node`` on a ``StateGraph``.
    None of these are defined in the visible part of this file -- confirm
    that the graph type in use actually provides them.
    """
    def __init__(self):
        # Strategies are applied in this order by optimize_graph().
        self.optimization_strategies = [
            self.merge_nodes,
            self.parallelize_nodes,
            self.cache_intermediate_results
        ]

    def optimize_graph(self, graph: StateGraph) -> StateGraph:
        """Copy ``graph`` and run every strategy over the copy in turn."""
        optimized_graph = graph.copy()

        for strategy in self.optimization_strategies:
            optimized_graph = strategy(optimized_graph)

        return optimized_graph

    def merge_nodes(self, graph: StateGraph) -> StateGraph:
        """Merge nodes that declare they can be merged with their successor.

        NOTE(review): ``self.create_merged_node`` is not defined on this
        class in the visible source; this method fails as soon as a
        mergeable pair is found unless it is provided elsewhere.
        """
        # Collect (node, successor) pairs that can be merged.
        nodes_to_merge = []

        for node_name in graph.nodes:
            node = graph.nodes[node_name]
            # Only nodes exposing can_merge_with_next() are candidates.
            if hasattr(node, 'can_merge_with_next'):
                next_nodes = graph._get_next_nodes(node_name)
                # A node with exactly one successor is the only case handled.
                if len(next_nodes) == 1:
                    next_node_name = next_nodes[0]
                    next_node = graph.nodes[next_node_name]

                    if node.can_merge_with_next(next_node):
                        nodes_to_merge.append((node_name, next_node_name))

        # Apply the merges after the scan so the node mapping is not
        # modified while it is being iterated.
        for current_node, next_node in nodes_to_merge:
            merged_node = self.create_merged_node(graph.nodes[current_node], graph.nodes[next_node])
            graph.add_node(f"{current_node}_merged", merged_node)
            graph.remove_edge(current_node, next_node)
            graph.remove_node(next_node)

        return graph

    def parallelize_nodes(self, graph: StateGraph) -> StateGraph:
        """Group independent nodes into a single ``ParallelNodeGroup`` node.

        NOTE(review): ``self.nodes_are_independent`` is not defined on this
        class in the visible source. Also, the group is constructed from node
        *names* only, and the original nodes are then removed from the graph
        -- verify the group has a way to obtain the node objects.
        """
        # Collect nodes with no predecessors, or whose predecessors are
        # reported as independent.
        independent_nodes = []

        for node_name in graph.nodes:
            # Check whether the node can run on its own.
            dependencies = graph._get_previous_nodes(node_name)
            if len(dependencies) == 0 or self.nodes_are_independent(graph, dependencies):
                independent_nodes.append(node_name)

        # Replace the independent nodes with one parallel group node.
        if len(independent_nodes) > 1:
            parallel_group = ParallelNodeGroup(independent_nodes)
            graph.add_node("parallel_group", parallel_group)

            for node_name in independent_nodes:
                graph.remove_node(node_name)

        return graph

    def cache_intermediate_results(self, graph: StateGraph) -> StateGraph:
        """Wrap every node of ``graph`` in a ``CachedNode``."""
        # Values are replaced under existing keys, so the mapping's size
        # does not change during iteration.
        for node_name in graph.nodes:
            node = graph.nodes[node_name]
            graph.nodes[node_name] = CachedNode(node)

        return graph

class ParallelNodeGroup:
    """Runs several nodes concurrently and merges what they produce.

    Args:
        node_names: names of the grouped nodes.
        nodes: the node objects to execute (each must expose an async
            ``execute(state)``). Optional for backward compatibility; it can
            also be assigned to ``self.nodes`` after construction. With no
            nodes the group is a no-op.
    """

    def __init__(self, node_names: List[str], nodes: Optional[list] = None):
        self.node_names = node_names
        self.nodes = list(nodes) if nodes is not None else []

    async def execute(self, state: "GraphState") -> "GraphState":
        """Execute all nodes on shallow copies of ``state`` and merge results.

        Merge rule: a key returned by a node is copied into ``state`` when it
        is new or bound to a different object than before the run; if several
        nodes change the same key, the later node in ``self.nodes`` wins.
        Containers mutated in place are shared through the shallow copies and
        need no copying. Messages of all failed nodes are joined with ``"; "``
        into ``state['error']``.
        """
        if not self.nodes:
            return state

        # Snapshot of the bindings before any node runs, used to detect changes.
        baseline = dict(state)

        results = await asyncio.gather(
            *(node.execute(state.copy()) for node in self.nodes),
            return_exceptions=True,
        )

        errors = []
        for result in results:
            if isinstance(result, Exception):
                errors.append(str(result))
                continue
            if not isinstance(result, dict):
                # Nothing mergeable was returned by this node.
                continue
            for key, value in result.items():
                if key not in baseline or baseline[key] is not value:
                    state[key] = value

        if errors:
            state['error'] = "; ".join(errors)

        return state

class CachedNode:
    """Wraps a node and memoises its result per input state.

    Cached states are stored and returned as deep copies. The wrapped nodes
    in this file mutate and return the state object they receive, so caching
    the object itself would let later mutations leak into the cache and into
    every subsequent cache hit.

    Args:
        original_node: object exposing an async ``execute(state)``.
        cache_size: maximum number of cached entries; the oldest entry is
            evicted first. Values <= 0 disable caching.
    """

    def __init__(self, original_node: "LangGraphNode", cache_size: int = 1000):
        from collections import deque

        self.original_node = original_node
        self.cache = {}
        # Keys in insertion order; deque gives O(1) eviction from the left.
        self.cache_order = deque()
        self.cache_size = cache_size

    async def execute(self, state: "GraphState") -> "GraphState":
        """Return the cached result for ``state`` or run the wrapped node."""
        import copy

        # The key must be computed before the node runs, because the node
        # may mutate ``state`` in place.
        cache_key = self.generate_cache_key(state)

        if cache_key in self.cache:
            # Hand out a private copy so callers cannot corrupt the entry.
            return copy.deepcopy(self.cache[cache_key])

        result = await self.original_node.execute(state)
        self.store_in_cache(cache_key, result)
        return result

    async def __call__(self, state: "GraphState") -> "GraphState":
        """Delegate to :meth:`execute` so the wrapper is usable as a callable."""
        return await self.execute(state)

    def generate_cache_key(self, state: "GraphState") -> str:
        """Return an MD5 hex digest of the state's items sorted by key."""
        import hashlib

        state_str = str(sorted(state.items()))
        # MD5 is used only as a fingerprint here, not for security.
        return hashlib.md5(state_str.encode()).hexdigest()

    def store_in_cache(self, key: str, value: "GraphState"):
        """Store a deep copy of ``value`` under ``key``, evicting if full."""
        import copy

        if self.cache_size <= 0:
            return

        try:
            snapshot = copy.deepcopy(value)
        except Exception:
            # A state that cannot be copied is simply not cached; caching is
            # an optimisation and must not break execution.
            return

        if key not in self.cache:
            # Evict the oldest entries until there is room for one more.
            while self.cache_order and len(self.cache) >= self.cache_size:
                old_key = self.cache_order.popleft()
                self.cache.pop(old_key, None)
            self.cache_order.append(key)

        self.cache[key] = snapshot

class AdaptiveWorkflow:
    """Workflow that can re-run itself on an optimised graph.

    Args:
        initial_graph: the graph to execute and, if needed, adapt.
    """

    def __init__(self, initial_graph: "StateGraph"):
        self.graph = initial_graph
        self.execution_history = []
        self.performance_metrics = {}

    async def execute_with_adaptation(self, initial_state: "GraphState") -> "GraphState":
        """Execute once, then re-run on an adapted graph if performance is poor.

        The adapted graph is only built when a re-run is actually going to
        happen, since it would otherwise be discarded.
        """
        final_state = await self.execute_original_graph(initial_state)

        self.analyze_execution(final_state)

        if self.should_rerun_with_adaptation():
            adapted_graph = self.adapt_graph_based_on_performance()
            final_state = await self.execute_graph(adapted_graph, initial_state)

        return final_state

    async def execute_original_graph(self, state: "GraphState") -> "GraphState":
        """Placeholder for running the original graph; returns ``state`` as is."""
        # Real graph execution belongs here; kept as a demonstration stub.
        return state

    async def execute_graph(self, graph: "StateGraph", state: "GraphState") -> "GraphState":
        """Run ``graph`` on ``state`` and return the resulting state.

        A graph builder exposing ``compile()`` is compiled first; an object
        without ``compile`` is assumed to be runnable already. The runnable
        is awaited through ``ainvoke``.
        """
        runnable = graph.compile() if hasattr(graph, 'compile') else graph
        return await runnable.ainvoke(state)

    def analyze_execution(self, state: "GraphState"):
        """Derive performance metrics from ``state['metadata']``.

        A missing or empty ``metadata`` entry is treated as "no measurements"
        (execution time 0, success rate 0) instead of raising ``KeyError``.
        """
        metadata = state.get('metadata') or {}
        execution_time = metadata.get('execution_time', 0)
        success_rate = metadata.get('success_rate', 0)

        self.performance_metrics = {
            'average_execution_time': execution_time,
            'success_rate': success_rate,
            'error_frequency': state.get('error') is not None
        }

    def adapt_graph_based_on_performance(self) -> "StateGraph":
        """Return an optimised version of ``self.graph``."""
        optimizer = GraphOptimizer()
        return optimizer.optimize_graph(self.graph)

    def should_rerun_with_adaptation(self) -> bool:
        """Return True when success rate < 0.8 or execution time > 30 seconds."""
        metrics = self.performance_metrics
        low_success = metrics.get('success_rate', 1.0) < 0.8
        too_slow = metrics.get('average_execution_time', 0) > 30  # 30-second threshold
        return low_success or too_slow

实际应用案例

客服智能助手

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
# 客服智能助手LangGraph实现
class CustomerServiceWorkflow:
    """Customer-service assistant workflow.

    Args:
        llm_client: client used by the LLM-backed nodes.
        knowledge_base: knowledge base used by the retrieval node.
    """

    # Score at or above which a response counts as appropriate when the
    # validator reported only a numeric ``appropriateness``.
    APPROPRIATENESS_THRESHOLD = 0.8
    # Number of regeneration attempts allowed before escalating to a human.
    MAX_RETRIES = 2

    def __init__(self, llm_client, knowledge_base):
        self.llm_client = llm_client
        self.knowledge_base = knowledge_base
        self.graph = self.build_customer_service_graph()

    def build_customer_service_graph(self) -> "StateGraph":
        """Build the (uncompiled) customer-service graph."""
        graph = StateGraph(GraphState)

        # The validator calls ``self.llm_client`` but takes no client in its
        # constructor, so the client is attached here.
        validator = ResponseValidationNode()
        validator.llm_client = self.llm_client

        # Customer-service nodes.
        graph.add_node("query_classifier", QueryClassificationNode(self.llm_client))
        graph.add_node("information_retrieval", InformationRetrievalNode(self.knowledge_base))
        graph.add_node("solution_generator", SolutionGenerationNode(self.llm_client))
        graph.add_node("response_validator", validator)
        graph.add_node("escalation_handler", EscalationHandlerNode())

        # Linear part of the flow.
        graph.add_edge("query_classifier", "information_retrieval")
        graph.add_edge("information_retrieval", "solution_generator")
        graph.add_edge("solution_generator", "response_validator")

        # Conditional edge: finish, escalate, or regenerate the solution.
        graph.add_conditional_edges(
            "response_validator",
            self.response_routing_decision,
            {
                "valid": END,
                "escalate": "escalation_handler",
                "retry": "solution_generator"
            }
        )

        graph.add_edge("escalation_handler", END)

        graph.set_entry_point("query_classifier")

        return graph

    def response_routing_decision(self, state: "GraphState"):
        """Return ``"valid"``, ``"escalate"`` or ``"retry"`` for the validator.

        An explicit ``is_appropriate`` value always decides appropriateness.
        When it is absent, a request for human intervention or a numeric
        ``appropriateness`` below the threshold marks the response as
        inappropriate, so the validator's JSON-failure fallback is escalated
        rather than accepted. With no verdict at all the response is treated
        as valid. Retries are capped at ``MAX_RETRIES``; after that the query
        is escalated.
        """
        validation_result = state.get('validation_result') or {}
        needs_human = bool(validation_result.get('requires_human_intervention', False))

        if 'is_appropriate' in validation_result:
            appropriate = bool(validation_result['is_appropriate'])
        elif needs_human:
            appropriate = False
        elif 'appropriateness' in validation_result:
            try:
                score = float(validation_result['appropriateness'])
            except (TypeError, ValueError):
                # An unreadable score cannot vouch for the response.
                appropriate = False
            else:
                appropriate = score >= self.APPROPRIATENESS_THRESHOLD
        else:
            appropriate = True

        if appropriate:
            return "valid"
        if needs_human:
            return "escalate"

        # Each generation attempt appends a 'solution_generation' step.
        attempts = sum(
            1
            for step in (state.get('intermediate_steps') or [])
            if isinstance(step, dict) and step.get('action') == 'solution_generation'
        )
        if attempts > self.MAX_RETRIES:
            return "escalate"
        return "retry"

class QueryClassificationNode(LangGraphNode):
    """Node that assigns the customer query to one of a fixed set of categories."""

    def __init__(self, llm_client):
        super().__init__("query_classifier", "分类客户查询类型")
        self.llm_client = llm_client
        self.categories = [
            "billing_inquiry", "technical_support", "product_info",
            "complaint", "account_management", "other",
        ]

    async def execute(self, state: GraphState) -> GraphState:
        """Classify the query and store the label in ``metadata['query_category']``.

        Answers outside the known categories, and any failure, fall back to
        ``"other"``; failures are also recorded in ``state['error']``.
        """
        question = state['input_query']

        prompt = f"""
请将以下客户查询分类到以下类别之一:
{', '.join(self.categories)}

查询:{question}

只需返回类别名称,无需其他说明。
"""

        try:
            raw_label = await self.llm_client.acall(prompt)
            label = raw_label.strip().lower()

            # Anything the model invents is folded into the catch-all bucket.
            label = label if label in self.categories else "other"

            state['metadata']['query_category'] = label

            step = dict(
                node=self.name,
                action="classification",
                input=question,
                output=label,
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Classified as {label}")
        except Exception as exc:
            state['error'] = f"Classification error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
            state['metadata']['query_category'] = "other"
        return state

class InformationRetrievalNode(LangGraphNode):
    """Node that pulls relevant documents from the knowledge base."""

    def __init__(self, knowledge_base):
        super().__init__("information_retrieval", "从知识库检索相关信息")
        self.knowledge_base = knowledge_base

    async def execute(self, state: GraphState) -> GraphState:
        """Retrieve documents for the query and append up to five to the context.

        Failures are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        question = state['input_query']
        label = state['metadata'].get('query_category', 'other')

        try:
            # Retrieval is scoped by the query category.
            docs = await self.knowledge_base.retrieve_relevant_docs(question, label)

            heading = f"{state.get('context', '')}\n\n相关知识库信息:\n"
            # Only the first five documents are listed.
            bullets = "".join(f"- {doc.content}\n" for doc in docs[:5])
            state['context'] = heading + bullets

            step = dict(
                node=self.name,
                action="retrieval",
                input=question,
                output=f"Retrieved {len(docs)} documents",
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Retrieved {len(docs)} documents")
        except Exception as exc:
            state['error'] = f"Information retrieval error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

class SolutionGenerationNode(LangGraphNode):
    """Node that drafts a solution for the customer's problem."""

    def __init__(self, llm_client):
        super().__init__("solution_generator", "生成客户问题的解决方案")
        self.llm_client = llm_client

    async def execute(self, state: GraphState) -> GraphState:
        """Generate a solution and store it in ``state['final_answer']``.

        Failures are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        question = state['input_query']
        background = state['context']
        label = state['metadata'].get('query_category', 'other')

        prompt = f"""
作为一名专业的客户服务代表,请为以下客户问题提供解决方案:

客户问题:{question}

相关信息:{background}

问题类别:{label}

请提供:
1. 直接解决方案或答案
2. 操作步骤(如适用)
3. 预期结果
4. 注意事项或免责声明

保持友好、专业的语气。
"""

        try:
            proposal = await self.llm_client.acall(prompt)
            state['final_answer'] = proposal

            step = dict(
                node=self.name,
                action="solution_generation",
                input=question,
                output=proposal[:100] + "...",  # truncated for the trace
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Generated solution")
        except Exception as exc:
            state['error'] = f"Solution generation error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

class ResponseValidationNode(LangGraphNode):
    """Node that asks the LLM to grade a customer-service response.

    Args:
        llm_client: client exposing an async ``acall(prompt)``. Optional so
            that ``ResponseValidationNode()`` keeps working; the client may
            also be assigned to ``self.llm_client`` after construction.
            Without a client, ``execute`` records an error.
    """

    def __init__(self, llm_client=None):
        super().__init__("response_validator", "验证响应的适当性")
        self.llm_client = llm_client
        self.appropriateness_threshold = 0.8
        self.completeness_threshold = 0.7

    async def execute(self, state: GraphState) -> GraphState:
        """Grade ``final_answer`` and store the verdict in ``validation_result``.

        The stored dict always carries a boolean ``is_appropriate`` when the
        scores allow deriving one, so the routing decision has an explicit
        verdict. Failures are stored in ``state['error']`` and the execution
        history instead of being raised.
        """
        # ``final_answer`` can be present but None.
        solution = state.get('final_answer') or ''
        query = state['input_query']

        try:
            llm_client = getattr(self, 'llm_client', None)
            if llm_client is None:
                raise RuntimeError("no llm_client configured for response validation")

            # The key names are spelled out so the reply can be interpreted.
            validation_prompt = f"""
请评估以下客户服务响应的质量:

客户问题:{query}
客户服务响应:{solution}

请评估:
1. 相关性(0-1分)
2. 完整性(0-1分)
3. 专业性(0-1分)
4. 是否需要人工干预(是/否)

以JSON格式返回评估结果,键名依次为 relevance、completeness、professionalism(0-1之间的数值)和 requires_human_intervention(true 或 false)。只返回JSON,不要附加其他文字。
"""

            validation_result = await llm_client.acall(validation_prompt)
            eval_result = self._normalise(self._parse_evaluation(validation_result))

            state['validation_result'] = eval_result

            intermediate_step = {
                "node": self.name,
                "action": "validation",
                "input": solution,
                "output": eval_result,
                "timestamp": asyncio.get_event_loop().time()
            }

            state['intermediate_steps'].append(intermediate_step)
            state['execution_history'].append(f"{self.name}: Validation complete")

            return state

        except Exception as e:
            state['error'] = f"Response validation error: {str(e)}"
            state['execution_history'].append(f"{self.name}: Error - {str(e)}")
            return state

    @staticmethod
    def _parse_evaluation(raw) -> Dict[str, Any]:
        """Parse the LLM reply as a JSON object, tolerating Markdown fences.

        Anything that is not a JSON object yields the conservative fallback
        ``{"appropriateness": 0.5, "requires_human_intervention": True}``.
        """
        fallback = {"appropriateness": 0.5, "requires_human_intervention": True}
        text = str(raw).strip()
        if text.startswith("```"):
            # Strip a fenced block such as ```json ... ```.
            text = text.strip("`").strip()
            if text[:4].lower() == "json":
                text = text[4:]
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return fallback
        return parsed if isinstance(parsed, dict) else fallback

    def _normalise(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Coerce the human-intervention flag and derive ``is_appropriate``."""
        flag = result.get('requires_human_intervention', False)
        if isinstance(flag, str):
            flag = flag.strip().lower() in ("是", "yes", "true", "y", "1")
        result['requires_human_intervention'] = bool(flag)

        if 'is_appropriate' not in result:
            relevance = self._as_score(result.get('relevance', result.get('appropriateness')))
            completeness = self._as_score(result.get('completeness'))
            if relevance is not None:
                result['is_appropriate'] = (
                    relevance >= self.appropriateness_threshold
                    and (completeness is None or completeness >= self.completeness_threshold)
                )
        return result

    @staticmethod
    def _as_score(value) -> Optional[float]:
        """Return ``value`` as a float, or None when it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

class EscalationHandlerNode(LangGraphNode):
    """Node that hands a query over to a human agent."""

    def __init__(self):
        super().__init__("escalation_handler", "处理需要人工干预的情况")
        self.escalation_reasons = []

    async def execute(self, state: GraphState) -> GraphState:
        """Record escalation details and replace the answer with a hand-off note."""
        verdict = state.get('validation_result', {})
        question = state['input_query']

        # Structured record of why and when the query was escalated.
        details = dict(
            original_query=question,
            validation_result=verdict,
            escalation_timestamp=asyncio.get_event_loop().time(),
            reason="High complexity or sensitive issue",
        )
        state['escalation_info'] = details

        # Human-readable hand-off message; it becomes the final answer.
        rendered_verdict = json.dumps(verdict, ensure_ascii=False, indent=2)
        handoff_note = f"""
需要人工处理的客户查询:

原始查询:{question}

评估结果:{rendered_verdict}

请人工处理此查询并提供解决方案。
"""
        state['final_answer'] = handoff_note

        step = dict(
            node=self.name,
            action="escalation",
            input=question,
            output=details,
            timestamp=asyncio.get_event_loop().time(),
        )
        state['intermediate_steps'].append(step)
        state['execution_history'].append(f"{self.name}: Escalated to human agent")

        return state

数据分析工作流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# 数据分析LangGraph实现
class DataAnalysisWorkflow:
    """Data-analysis workflow: understand, discover, extract, analyse, report.

    Args:
        llm_client: client used by the LLM-backed nodes.
        data_connector: connector used for table discovery and queries.
    """

    def __init__(self, llm_client, data_connector):
        self.llm_client = llm_client
        self.data_connector = data_connector
        self.graph = self.build_data_analysis_graph()

    def build_data_analysis_graph(self) -> StateGraph:
        """Build the (uncompiled) linear data-analysis graph."""
        graph = StateGraph(GraphState)

        # The extraction node generates its SQL through ``self.llm_client``
        # but its constructor only accepts the connector, so the client is
        # attached here.
        extraction_node = DataExtractionNode(self.data_connector)
        extraction_node.llm_client = self.llm_client

        # Nodes, listed in execution order.
        graph.add_node("query_understanding", QueryUnderstandingNode(self.llm_client))
        graph.add_node("data_discovery", DataDiscoveryNode(self.data_connector))
        graph.add_node("data_extraction", extraction_node)
        graph.add_node("data_analysis", DataAnalysisNode(self.llm_client))
        graph.add_node("visualization_planning", VisualizationPlanningNode(self.llm_client))
        graph.add_node("report_generation", ReportGenerationNode(self.llm_client))

        # Straight pipeline from understanding to the final report.
        graph.add_edge("query_understanding", "data_discovery")
        graph.add_edge("data_discovery", "data_extraction")
        graph.add_edge("data_extraction", "data_analysis")
        graph.add_edge("data_analysis", "visualization_planning")
        graph.add_edge("visualization_planning", "report_generation")
        graph.add_edge("report_generation", END)

        graph.set_entry_point("query_understanding")

        return graph

class QueryUnderstandingNode(LangGraphNode):
    """Node that interprets the intent of a data-analysis query."""

    def __init__(self, llm_client):
        super().__init__("query_understanding", "理解数据分析查询意图")
        self.llm_client = llm_client

    async def execute(self, state: GraphState) -> GraphState:
        """Ask the LLM to break the query down; store it in ``metadata['query_analysis']``.

        Failures are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        question = state['input_query']

        prompt = f"""
请分析以下数据分析查询:

查询:{question}

请识别:
1. 分析目标
2. 涉及的实体/指标
3. 时间范围
4. 维度/分组要求
5. 排序或过滤需求
"""

        try:
            breakdown = await self.llm_client.acall(prompt)
            state['metadata']['query_analysis'] = breakdown

            step = dict(
                node=self.name,
                action="query_understanding",
                input=question,
                output=breakdown,
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Query understood")
        except Exception as exc:
            state['error'] = f"Query understanding error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

class DataDiscoveryNode(LangGraphNode):
    """Node that looks up data sources relevant to the analysed query."""

    def __init__(self, data_connector):
        super().__init__("data_discovery", "发现相关数据源")
        self.data_connector = data_connector

    async def execute(self, state: GraphState) -> GraphState:
        """Discover relevant tables and store them in ``metadata['relevant_tables']``.

        Failures are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        breakdown = state['metadata'].get('query_analysis', '')

        try:
            # The connector decides which tables match the query analysis.
            tables = await self.data_connector.discover_relevant_tables(breakdown)
            state['metadata']['relevant_tables'] = tables

            step = dict(
                node=self.name,
                action="data_discovery",
                input=breakdown,
                output=f"Found {len(tables)} tables: {tables}",
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Discovered {len(tables)} tables")
        except Exception as exc:
            state['error'] = f"Data discovery error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

class DataExtractionNode(LangGraphNode):
    """Node that generates a SQL query via the LLM and runs it.

    Args:
        data_connector: connector exposing an async ``execute_query(sql)``.
        llm_client: client exposing an async ``acall(prompt)``, used to write
            the SQL. Optional so that ``DataExtractionNode(connector)`` keeps
            working; the client may also be assigned to ``self.llm_client``
            after construction. Without a client, ``execute`` records an error.
    """

    def __init__(self, data_connector, llm_client=None):
        super().__init__("data_extraction", "从数据源提取数据")
        self.data_connector = data_connector
        self.llm_client = llm_client

    async def execute(self, state: GraphState) -> GraphState:
        """Generate SQL, run it, and store the rows in ``metadata['extracted_data']``.

        Failures (including a rejected SQL statement) are stored in
        ``state['error']`` and the execution history instead of being raised.
        """
        relevant_tables = state['metadata'].get('relevant_tables', [])
        query_analysis = state['metadata'].get('query_analysis', '')

        try:
            llm_client = getattr(self, 'llm_client', None)
            if llm_client is None:
                raise RuntimeError("no llm_client configured for SQL generation")

            # Build the query from the analysis of the user's request.
            generated = await llm_client.acall(
                f"为以下需求生成SQL查询:{query_analysis}\n相关表:{relevant_tables}"
            )
            sql_query = self._clean_sql(generated)

            # SECURITY: the SQL text comes from a model and is untrusted.
            # This check is best-effort only; the connector should also use
            # read-only credentials.
            if not self._looks_read_only(sql_query):
                raise ValueError("generated SQL is not a single read-only statement")

            extracted_data = await self.data_connector.execute_query(sql_query)

            state['metadata']['extracted_data'] = extracted_data
            state['context'] = f"{state.get('context', '')}\n提取的数据:{extracted_data}"

            intermediate_step = {
                "node": self.name,
                "action": "data_extraction",
                "input": sql_query,
                "output": f"Extracted {len(extracted_data)} records",
                "timestamp": asyncio.get_event_loop().time()
            }

            state['intermediate_steps'].append(intermediate_step)
            state['execution_history'].append(f"{self.name}: Extracted {len(extracted_data)} records")

            return state

        except Exception as e:
            state['error'] = f"Data extraction error: {str(e)}"
            state['execution_history'].append(f"{self.name}: Error - {str(e)}")
            return state

    @staticmethod
    def _clean_sql(raw) -> str:
        """Strip whitespace and a surrounding Markdown code fence, if any."""
        text = str(raw).strip()
        if text.startswith("```"):
            text = text.strip("`").strip()
            if text[:3].lower() == "sql":
                text = text[3:].strip()
        return text

    @staticmethod
    def _looks_read_only(sql: str) -> bool:
        """Return True for a single statement starting with SELECT or WITH."""
        body = sql.strip().rstrip(';').strip()
        if not body or ';' in body:
            # Empty text, or several statements chained together.
            return False
        return body.lower().startswith(('select', 'with'))

class DataAnalysisNode(LangGraphNode):
    """Node that asks the LLM to analyse the extracted data."""

    def __init__(self, llm_client):
        super().__init__("data_analysis", "分析提取的数据")
        self.llm_client = llm_client

    async def execute(self, state: GraphState) -> GraphState:
        """Analyse a sample of the data and store the result.

        The result is appended to ``state['context']`` and stored in
        ``metadata['analysis_result']``. Failures, including data that cannot
        be sliced, are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        extracted_data = state['metadata'].get('extracted_data', [])
        query = state['input_query']

        try:
            # Only the first 10 records are sent, to keep the prompt short.
            # This note used to sit inside the f-string and was sent to the
            # model as part of the prompt.
            sample = extracted_data[:10]

            analysis_prompt = f"""
请分析以下数据并回答查询:{query}

数据:{sample}

请提供:
1. 关键发现
2. 趋势分析
3. 异常值或注意事项
4. 建议的行动方案
"""

            analysis_result = await self.llm_client.acall(analysis_prompt)

            state['context'] = f"{state.get('context', '')}\n\n分析结果:{analysis_result}"
            state['metadata']['analysis_result'] = analysis_result

            intermediate_step = {
                "node": self.name,
                "action": "data_analysis",
                "input": f"Data with {len(extracted_data)} records",
                "output": analysis_result[:100] + "...",
                "timestamp": asyncio.get_event_loop().time()
            }

            state['intermediate_steps'].append(intermediate_step)
            state['execution_history'].append(f"{self.name}: Data analysis complete")

            return state

        except Exception as e:
            state['error'] = f"Data analysis error: {str(e)}"
            state['execution_history'].append(f"{self.name}: Error - {str(e)}")
            return state

class VisualizationPlanningNode(LangGraphNode):
    """Node that proposes how the analysis should be visualised."""

    def __init__(self, llm_client):
        super().__init__("visualization_planning", "规划数据可视化")
        self.llm_client = llm_client

    async def execute(self, state: GraphState) -> GraphState:
        """Ask the LLM for a chart plan; store it in ``metadata['visualization_plan']``.

        Failures are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        findings = state['metadata'].get('analysis_result', '')
        question = state['input_query']

        prompt = f"""
基于以下分析结果,建议合适的数据可视化方案:

查询:{question}
分析结果:{findings}

请建议:
1. 最合适的图表类型
2. 需要展示的关键指标
3. 颜色主题建议
4. 交互功能需求
"""

        try:
            plan = await self.llm_client.acall(prompt)
            state['metadata']['visualization_plan'] = plan

            step = dict(
                node=self.name,
                action="visualization_planning",
                input=question,
                output=plan,
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Visualization planned")
        except Exception as exc:
            state['error'] = f"Visualization planning error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

class ReportGenerationNode(LangGraphNode):
    """Node that writes the final analysis report."""

    def __init__(self, llm_client):
        super().__init__("report_generation", "生成最终分析报告")
        self.llm_client = llm_client

    async def execute(self, state: GraphState) -> GraphState:
        """Generate the report and store it in ``state['final_answer']``.

        Failures are stored in ``state['error']`` and the execution history
        instead of being raised.
        """
        question = state['input_query']
        findings = state['metadata'].get('analysis_result', '')
        plan = state['metadata'].get('visualization_plan', '')

        prompt = f"""
请生成一份完整的数据分析报告:

查询:{question}
分析结果:{findings}
可视化方案:{plan}

报告应包括:
1. 执行摘要
2. 详细分析
3. 关键发现
4. 可视化建议
5. 行动建议
6. 数据来源说明
"""

        try:
            report = await self.llm_client.acall(prompt)
            state['final_answer'] = report

            step = dict(
                node=self.name,
                action="report_generation",
                input="Analysis results and viz plan",
                output=report[:100] + "...",  # truncated for the trace
                timestamp=asyncio.get_event_loop().time(),
            )
            state['intermediate_steps'].append(step)
            state['execution_history'].append(f"{self.name}: Report generated")
        except Exception as exc:
            state['error'] = f"Report generation error: {exc}"
            state['execution_history'].append(f"{self.name}: Error - {exc}")
        return state

性能优化与最佳实践

图执行优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# LangGraph性能优化
class ExecutionOptimizer:
    """Plan, run, cache and time executions of a node graph.

    The graph object is expected to provide ``entry_points``,
    ``_get_next_nodes(name)``, ``_get_previous_nodes(name)``,
    ``get_signature()`` and a ``nodes`` mapping from node name to an object
    with an ``async execute(state)`` method.
    NOTE(review): these are not part of the public LangGraph ``StateGraph``
    API -- confirm that the graph wrapper used here supplies them.
    """

    def __init__(self):
        # cache key -> private deep copy of the final state of a clean run
        self.execution_cache = {}
        # one metrics dict per non-cached execution
        self.performance_monitors = []

    def optimize_execution(self, graph: "StateGraph", state: "GraphState") -> tuple:
        """Run ``graph`` on ``state``, reusing a cached result when possible.

        Returns:
            ``(final_state, from_cache)``. Cached states are deep-copied on
            the way in and out so callers cannot corrupt the cache.
            Results carrying a truthy ``'error'`` are not cached, so a
            failed run is retried on the next call.
        """
        import copy
        import time

        # The key is computed before execution: nodes may mutate ``state``.
        cache_key = self.generate_cache_key(graph, state)
        if cache_key in self.execution_cache:
            return copy.deepcopy(self.execution_cache[cache_key]), True

        # A monotonic clock is used instead of the event loop clock: no loop
        # is current in synchronous code once ``asyncio.run`` has returned.
        started_at = time.monotonic()
        result = self.execute_graph_optimized(graph, state)
        execution_time = time.monotonic() - started_at

        if not result.get('error'):
            self.execution_cache[cache_key] = copy.deepcopy(result)

        self.log_performance(execution_time, graph, state)
        return result, False

    def execute_graph_optimized(self, graph: "StateGraph", state: "GraphState") -> "GraphState":
        """Execute every step of the plan for ``graph`` in one event loop.

        Uses ``asyncio.run`` and therefore must not be called from code that
        is already running inside an event loop.
        """
        execution_plan = self.create_execution_plan(graph)
        return asyncio.run(self._run_plan(execution_plan, graph, state))

    async def _run_plan(self, plan: List[Dict[str, Any]], graph: "StateGraph",
                        state: "GraphState") -> "GraphState":
        """Run the plan steps in order; unknown step types are skipped."""
        current_state = state
        for step in plan:
            if step['type'] == 'sequential':
                current_state = await self._run_sequential(step['nodes'], current_state, graph)
            elif step['type'] == 'parallel':
                current_state = await self.execute_parallel(step['nodes'], current_state, graph)
        return current_state

    def create_execution_plan(self, graph: "StateGraph") -> List[Dict[str, Any]]:
        """Build a level-by-level plan for the nodes reachable from the entries.

        Nodes whose predecessors have all been scheduled form one level.
        A level with several nodes becomes a ``'parallel'`` step (there is no
        edge between nodes of the same level); single-node levels are folded
        into the preceding ``'sequential'`` step. Nodes that sit on a cycle
        have no valid position and are appended sequentially in discovery
        order.

        Returns:
            A list of ``{'type': 'sequential' | 'parallel', 'nodes': [...]}``.
        """
        # Breadth-first discovery, remembering a deterministic order.
        discovered = list(dict.fromkeys(graph.entry_points))
        seen = set(discovered)
        successors = {}
        cursor = 0
        while cursor < len(discovered):
            name = discovered[cursor]
            cursor += 1
            # De-duplicate so a repeated edge does not inflate the in-degree.
            successors[name] = list(dict.fromkeys(graph._get_next_nodes(name)))
            for target in successors[name]:
                if target not in seen:
                    seen.add(target)
                    discovered.append(target)

        indegree = {name: 0 for name in discovered}
        for targets in successors.values():
            for target in targets:
                indegree[target] += 1

        plan = []
        scheduled = set()
        ready = [name for name in discovered if indegree[name] == 0]
        while ready:
            self._append_step(plan, ready)
            scheduled.update(ready)
            upcoming = []
            for name in ready:
                for target in successors[name]:
                    indegree[target] -= 1
                    if indegree[target] == 0:
                        upcoming.append(target)
            ready = upcoming

        # Whatever is left is part of (or behind) a cycle.
        for name in discovered:
            if name not in scheduled:
                self._append_step(plan, [name])

        return plan

    @staticmethod
    def _append_step(plan: List[Dict[str, Any]], names: List[str]) -> None:
        """Append ``names`` to ``plan`` as a parallel or sequential step."""
        if len(names) > 1:
            plan.append({'type': 'parallel', 'nodes': list(names)})
        elif plan and plan[-1]['type'] == 'sequential':
            plan[-1]['nodes'].extend(names)
        else:
            plan.append({'type': 'sequential', 'nodes': list(names)})

    def have_common_dependencies(self, graph: "StateGraph", nodes1: List[str], nodes2: List[str]) -> bool:
        """Return True when the two node groups share a direct predecessor."""
        deps1 = set()
        for node in nodes1:
            deps1.update(graph._get_previous_nodes(node))

        deps2 = set()
        for node in nodes2:
            deps2.update(graph._get_previous_nodes(node))

        return not deps1.isdisjoint(deps2)

    @staticmethod
    def _resolve_node(node_name: str, state: "GraphState", graph: Optional["StateGraph"] = None):
        """Look up a node, preferring ``graph`` and falling back to ``state.graph``.

        Raises:
            ValueError: when neither source provides a graph.
        """
        owner = graph if graph is not None else getattr(state, 'graph', None)
        if owner is None:
            raise ValueError(
                f"Cannot resolve node '{node_name}': no graph was provided"
            )
        return owner.nodes[node_name]

    @staticmethod
    def _merge_branch(merged: Dict[str, Any], base: Dict[str, Any], branch: Dict[str, Any]) -> None:
        """Fold the changes ``branch`` made relative to ``base`` into ``merged``.

        Lists are treated as append-only logs (only the items added past the
        base length are kept), dicts contribute their new or changed keys,
        and any other changed value overwrites -- so when several branches
        change the same scalar, the last branch wins.
        """
        for key, value in branch.items():
            baseline = base.get(key)
            target = merged.get(key)
            if isinstance(value, list) and isinstance(baseline, list) and isinstance(target, list):
                target.extend(value[len(baseline):])
            elif isinstance(value, dict) and isinstance(baseline, dict) and isinstance(target, dict):
                target.update({
                    k: v for k, v in value.items()
                    if k not in baseline or baseline[k] != v
                })
            elif value != baseline:
                merged[key] = value

    async def execute_parallel(self, node_names: List[str], state: "GraphState",
                               graph: Optional["StateGraph"] = None) -> "GraphState":
        """Run the named nodes concurrently and merge their results.

        Each node receives its own deep copy of ``state`` so concurrent
        branches cannot interfere; ``state`` itself is left untouched and a
        new merged state is returned. An exception escaping a node is
        recorded in ``'error'`` of the returned state.
        """
        import copy

        # Resolve first so a bad name fails before any coroutine is created.
        nodes = [self._resolve_node(name, state, graph) for name in node_names]
        results = await asyncio.gather(
            *(node.execute(copy.deepcopy(state)) for node in nodes),
            return_exceptions=True,
        )

        final_state = copy.deepcopy(state)
        for result in results:
            if isinstance(result, BaseException):
                final_state['error'] = str(result)
            elif isinstance(result, dict):
                self._merge_branch(final_state, state, result)

        return final_state

    async def _run_sequential(self, node_names: List[str], state: "GraphState",
                              graph: Optional["StateGraph"] = None) -> "GraphState":
        """Await the named nodes one after another, threading the state."""
        current_state = state
        for node_name in node_names:
            node = self._resolve_node(node_name, state, graph)
            current_state = await node.execute(current_state)
        return current_state

    def execute_sequential(self, node_names: List[str], state: "GraphState",
                           graph: Optional["StateGraph"] = None) -> "GraphState":
        """Synchronously run the named nodes in order.

        Uses ``asyncio.run`` and therefore must not be called from code that
        is already running inside an event loop.
        """
        return asyncio.run(self._run_sequential(node_names, state, graph))

    def generate_cache_key(self, graph: "StateGraph", state: "GraphState") -> str:
        """Return ``"<graph signature>:<state digest>"`` for cache lookups.

        The state is serialized as key-sorted JSON so equal states produce
        the same digest regardless of dict insertion order; values JSON
        cannot encode are deliberately reduced to ``str(value)``.
        """
        import hashlib

        graph_signature = graph.get_signature()  # assumed to exist on the graph
        canonical_state = json.dumps(
            state, sort_keys=True, default=str, ensure_ascii=False
        )
        state_hash = hashlib.sha256(canonical_state.encode('utf-8')).hexdigest()
        return f"{graph_signature}:{state_hash}"

    def log_performance(self, execution_time: float, graph: "StateGraph", state: "GraphState"):
        """Append one metrics record to ``performance_monitors``."""
        import time

        self.performance_monitors.append({
            'execution_time': execution_time,
            'graph_size': len(graph.nodes),
            'state_size': len(str(state)),
            # Monotonic clock reading; not a wall-clock date.
            'timestamp': time.monotonic(),
        })

总结

  • LangGraph提供了强大的图状AI工作流构建能力
  • 节点设计模式便于模块化开发和维护
  • 条件分支和决策逻辑增强了工作流的灵活性
  • 实际应用案例展示了复杂场景的解决方案
  • 性能优化技术保障了系统的高效运行
  • 缓存和并行化提升了执行效率
  • 自适应调整使系统能够持续优化

LangGraph就像是AI世界的“乐高积木”,通过组合不同的节点和连接,我们可以构建出解决复杂问题的智能系统。这种图形化的思维方式让AI工作流的设计变得更加直观和可控。

未来发展

  1. 可视化工具: 更直观的图形化工作流设计器
  2. 自动化优化: AI驱动的工作流自动优化
  3. 实时监控: 详细的工作流执行监控和调试
  4. 标准化: 更统一的节点接口和工作流规范
  5. 集成能力: 与更多外部系统的无缝集成

扩展阅读

  1. LangGraph Documentation
  2. State Machines in AI Applications
  3. Graph Neural Networks for AI Workflows
bulb