How APR-MCTS Works

Core Data Structures

The Node Structure

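class Node:
    def __init__(self, bug: "Bug"):  # Bug is defined below
        self.bug = bug                  # current bug state
        self.num_visits = 0             # visit count
        self.V = 0                      # node value (reward)
        self.parent = None              # parent node
        self.children = []              # list of child nodes
        self.motivation = None          # repair motivation / reflection
        self.patches = []               # generated patches
        self.patch_diffs = []           # diffs of the patches
        self.can_fix = False            # whether the bug can be fixed
        self.is_fully_expand = False    # whether the node is fully expanded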

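In expand(node: Node), each call to the LLM for a repair patch typically yields a response with two parts: a reflection and a patch. extract_reflection_from_response(response) extracts the reflection (the motivation), which is usually the explanation or reasoning that precedes the patch code. When a child node is created, the extracted reflection is assigned to child.motivation.

self.patch_diffs stores the diff between the original code and each patch a node generates during expansion. After a patch is generated and validated, framework.validate_patch returns its git diff (the actual code change), which is saved in self.patch_diffs. This makes it easy to analyze, log, and display exactly what each patch changed.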
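The split between reflection and patch can be sketched as follows. This is only an illustration under stated assumptions, not the project's code: the name split_response is hypothetical, and the sketch assumes the response contains one java-tagged fenced block, with everything before it treated as the reflection.

```python
import re

FENCE = "`" * 3  # three backticks, built programmatically so this example nests cleanly

# Assumed response layout: free-text reasoning, then one java-tagged fenced block.
_PATCH_RE = re.compile(FENCE + r"java\s*\n(.*?)" + FENCE, re.DOTALL)


def split_response(response: str) -> tuple[str, str]:
    """Return (reflection, patch); patch is "" when no fenced block is found."""
    match = _PATCH_RE.search(response)
    if match is None:
        return response.strip(), ""
    reflection = response[:match.start()].strip()  # text preceding the patch
    patch = match.group(1).rstrip("\n")            # code inside the fence
    return reflection, patch


response = (
    "The loop bound is off by one, so the last element is skipped.\n"
    + FENCE + "java\n"
    + "for (int i = 0; i <= n; i++) {\n"
    + FENCE + "\n"
)
reflection, patch = split_response(response)
```

Here reflection would become the child's motivation and patch is what gets validated.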

Bug State Representation

Each node holds a Bug object that records:

  • Project information, bug ID, and bug type (SL/SH/SF)
  • Source code, masked code, and test code
  • Failing tests, error messages, and so on

class Bug(object):
    def __init__(self,
                 test_framework,
                 project,
                 bug_id,
                 bug_type,
                 code,
                 masked_code,
                 fixed_code,
                 buggy_lines,
                 fixed_lines,
                 test_code,
                 extract_test_code,
                 failing_tests,
                 test_suite,
                 test_name,
                 test_line,
                 test_error_message,
                 test_input=None,
                 test_output=None,
                 expected_output=None):
        self.test_framework = test_framework
        self.project = project
        self.bug_id = bug_id
        self.bug_type = bug_type
        self.code = code
        self.masked_code = masked_code
        self.fixed_code = fixed_code
        self.buggy_lines = buggy_lines
        self.fixed_lines = fixed_lines
        self.test_code = test_code
        self.extract_test_code = extract_test_code
        self.failing_tests = failing_tests
        self.test_suite = test_suite
        self.test_name = test_name
        self.test_line = test_line
        self.test_error_message = test_error_message
        # Fields added for ConDefects
        self.test_input = test_input
        self.test_output = test_output
        self.expected_output = expected_output

The Four Core MCTS Steps

Selection

def select_node(node: Node):
    # Descend while the current node is fully expanded
    while node.is_fully_expanded():
        node = get_best_child(node)  # pick the best child with the UCB formula
    return is_terminal(node), node

Nodes are selected with the **UCB (Upper Confidence Bound)** formula:

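import math

def get_best_child(node: Node):
    # UCB = child value (exploitation) + exploration bonus; taking the arg-max is the assumed completion of this excerpt
    return max(node.children, key=lambda child: child.V + exploration_constant * math.sqrt(
        2 * math.log(node.num_visits) / child.num_visits))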
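To see how this score trades off exploitation against exploration, here is a small self-contained sketch. DemoNode, ucb_score, best_child, and the constant 0.7 are illustrative assumptions, not the project's actual names or values. Treating unvisited children as infinitely attractive is a common MCTS convention, added here only to avoid a division by zero.

```python
import math
from dataclasses import dataclass, field


@dataclass
class DemoNode:
    """Hypothetical stand-in for Node, holding only what UCB needs."""
    V: float = 0.0
    num_visits: int = 0
    children: list = field(default_factory=list)


def ucb_score(parent: DemoNode, child: DemoNode, exploration_constant: float) -> float:
    if child.num_visits == 0:
        return math.inf  # assumption: always try unvisited children first
    bonus = math.sqrt(2 * math.log(parent.num_visits) / child.num_visits)
    return child.V + exploration_constant * bonus


def best_child(parent: DemoNode, exploration_constant: float = 0.7) -> DemoNode:
    return max(parent.children,
               key=lambda ch: ucb_score(parent, ch, exploration_constant))


root = DemoNode(V=0.5, num_visits=10)
well_explored = DemoNode(V=0.6, num_visits=8)   # higher value, visited often
rarely_seen = DemoNode(V=0.4, num_visits=2)     # lower value, visited rarely
root.children = [well_explored, rarely_seen]

chosen = best_child(root)            # exploration bonus favors rarely_seen
greedy = best_child(root, 0.0)       # with no bonus, the higher V wins
```

With the bonus enabled, the rarely visited child scores about 1.46 against about 1.13 for the well-explored one, so the search keeps probing it even though its current value is lower.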

Expansion

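def expand(node: Node):
    bug = node.bug

    # 1. Build the repair prompt
    repair_prompt = prompt.construct_gpt_policy_prompt(bug=bug, mode=mode)

    # 2. Ask the LLM for several candidate patches
    responses = generate_patches_gpt(prompt=repair_prompt, num_samples=branch)

    # 3. Validate each patch
    for response in responses:
        patch = extract_patch_from_response(response, mode=mode)
        reflection = extract_reflection_from_response(response)
        test_result, result_reason, patch_diff = framework.validate_patch(
            bug=bug, proposed_patch=patch, mode=mode)

        # 4. Create the child node (creation is elided in the excerpt; these three lines are an assumed reconstruction)
        child = Node(bug)
        child.parent = node
        child.motivation = reflection
        if test_result == "PASS":
            node.can_fix = True

        # 5. Compute the reward
        reward = get_reward(bug, patch, reflection)
        child.V = reward
        node.children.append(child)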

Simulation

The traditional simulation (rollout) phase is skipped; the LLM's feedback is used directly as the evaluation:

# skip simulating
print(f'Skip Simulating, Round={round_num}\n')

Backpropagation

def back_propagate(node):
    while node is not None:
        node.num_visits += 1
        if node.is_fully_expanded():
            # Update the node value with a visit-weighted average of its children
            child_Vs = [child.V * child.num_visits for child in node.children]
            total_num_visits = sum([child.num_visits for child in node.children])
            node.V = alpha * sum(child_Vs) / total_num_visits + (1 - alpha) * node.V
        node = node.parent
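The update rule is easier to follow with concrete numbers. The sketch below isolates the arithmetic for a single node. The helper name blended_value, the alpha value of 0.8, and the guard for zero total visits are assumptions made for illustration only.

```python
def blended_value(node_v: float, children: list[tuple[float, int]], alpha: float) -> float:
    """Blend a node's old value with the visit-weighted mean of its children.

    children is a list of (V, num_visits) pairs.
    """
    total_visits = sum(visits for _, visits in children)
    if total_visits == 0:
        return node_v  # assumption: leave the value unchanged when no child was visited
    weighted_mean = sum(v * visits for v, visits in children) / total_visits
    return alpha * weighted_mean + (1 - alpha) * node_v


# Old value 0.2; children with V=0.9 (3 visits) and V=0.3 (1 visit).
# Weighted mean = (0.9*3 + 0.3*1) / 4 = 0.75
# New value     = 0.8*0.75 + 0.2*0.2  = 0.64 (up to floating-point rounding)
new_v = blended_value(0.2, [(0.9, 3), (0.3, 1)], alpha=0.8)
```

A larger alpha makes the node track its children more closely, while a smaller alpha keeps more of the value the node was originally assigned.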

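Reward Design

Patch Validation Results

  • PASS: the patch passes all tests → the search terminates
  • FAIL: the patch compiles but tests fail → the search continues
  • ERROR: compilation error → reward of -1

LLM Scoring

A dedicated reward model scores patch quality on a 0-100 scale: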

def get_reward(bug: Bug, wrong_patch, reflection):
    reward_prompt = prompt.construct_gpt_reward_prompt(
        bug=bug, wrong_patch=wrong_patch, reflection=reflection, mode=mode)
    response = generate_gpt(reward_prompt, model_name=reward_model)
    # Extract the numeric score
    score = extract_score_from_response(response)
    return score / 100  # normalize to 0-1
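How the numeric score is pulled out of free-form model text is not shown above. One plausible approach, sketched here under assumptions, is to take the first integer in the 0-100 range. The names extract_score and normalized_reward and the fallback default are hypothetical, and a real extractor would probably need to be stricter about where the number appears.

```python
import re


def extract_score(response: str, default: float = 0.0) -> float:
    """Return the first integer between 0 and 100 found in the response."""
    for token in re.findall(r"\d+", response):
        value = int(token)
        if 0 <= value <= 100:
            return float(value)
    return default  # assumption: fall back to the lowest score when nothing parses


def normalized_reward(response: str) -> float:
    """Map the 0-100 score onto the 0-1 range used for node values."""
    return extract_score(response) / 100


reward = normalized_reward("Score: 85. The patch fixes the null check.")
```

Note that this naive version picks up any small number that happens to come first in the text, so asking the reward model for a fixed output format would make parsing more reliable.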

Search Strategy

  • Each expansion generates branch candidate patches (4 by default)
  • max_expansion caps the number of child nodes
  • max_rollout caps the number of search rounds
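The following toy loop sketches how the three limits could interact. It is a simplified model, not the project's driver: ToyNode, search, and the evaluate callback are hypothetical, the defaults for alpha and c are arbitrary, and a node is assumed to count as fully expanded once it has max_expansion children.

```python
import math


class ToyNode:
    def __init__(self, parent=None):
        self.parent = parent
        self.children = []
        self.V = 0.0
        self.num_visits = 0
        self.can_fix = False


def ucb(parent, child, c):
    if child.num_visits == 0:
        return math.inf
    return child.V + c * math.sqrt(2 * math.log(parent.num_visits) / child.num_visits)


def select(node, max_expansion, c):
    # Descend while the node already has its full quota of children.
    while len(node.children) >= max_expansion:
        node = max(node.children, key=lambda ch: ucb(node, ch, c))
    return node


def expand(node, branch, max_expansion, evaluate):
    # evaluate() stands in for "generate a patch, validate it, score it".
    budget = min(branch, max_expansion - len(node.children))
    for _ in range(budget):
        reward, passed = evaluate()
        child = ToyNode(parent=node)
        child.V = reward
        child.num_visits = 1
        node.children.append(child)
        if passed:
            node.can_fix = True
    return node.can_fix


def back_propagate(node, alpha):
    while node is not None:
        node.num_visits += 1
        total = sum(ch.num_visits for ch in node.children)
        if total > 0:
            mean = sum(ch.V * ch.num_visits for ch in node.children) / total
            node.V = alpha * mean + (1 - alpha) * node.V
        node = node.parent


def search(evaluate, branch=4, max_expansion=4, max_rollout=8, alpha=0.8, c=0.7):
    """Return (fixed, rollouts_used)."""
    root = ToyNode()
    for rollout in range(1, max_rollout + 1):
        leaf = select(root, max_expansion, c)
        if expand(leaf, branch, max_expansion, evaluate):
            return True, rollout  # a PASS result ends the search
        back_propagate(leaf, alpha)
    return False, max_rollout


# Every candidate fails with a mediocre score, so the rollout budget is exhausted.
result = search(lambda: (0.3, False))
```

In this model, max_rollout bounds the total number of LLM calls at roughly branch × max_rollout, which is the main cost control.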

Supported Bug Types

The classification logic in defects4j.sh shows that the system sorts bugs into four types based on the Git diff:

SINGLE_LINE="SL SH SF"   # single-line fix
SINGLE_HUNK="SH SF"      # single-hunk fix
SINGLE_FUNCTION="SF"     # single-function fix
OTHER="OT"               # other (multi-function / multi-file fix)

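Classification criteria:

  • Number of files > 2 → OT (Other)
  • Changed lines < 2 → SL (Single Line)
  • Contiguous changes within one function → SH (Single Hunk)
  • Non-contiguous changes within one function → SF (Single Function)
  • Changes spanning multiple functions → OT (Other)

The modes differ in how a patch is applied to the code:

  • SF mode (Single Function): the patch replaces the entire function
  • Non-SF modes: the patch is filled into a template at the >>> [ INFILL ] <<< marker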
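The two application styles can be sketched in a few lines. This is a minimal illustration, not the framework's validate_patch logic: apply_patch is a hypothetical helper, and it assumes the masked code contains the marker exactly as written above.

```python
INFILL_MARK = ">>> [ INFILL ] <<<"


def apply_patch(masked_code: str, patch: str, mode: str) -> str:
    """Return the patched function source for the given repair mode."""
    if mode == "SF":
        # Single Function: the patch already is the complete replacement function.
        return patch
    if INFILL_MARK not in masked_code:
        raise ValueError("masked code has no infill marker")
    # SL / SH: substitute the patch at the (first) infill location.
    return masked_code.replace(INFILL_MARK, patch, 1)


masked = (
    "int sum(int[] a) {\n"
    "    int s = 0;\n"
    ">>> [ INFILL ] <<<\n"
    "        s += a[i];\n"
    "    return s;\n"
    "}"
)
fixed_line = "    for (int i = 0; i < a.length; i++)"
patched = apply_patch(masked, fixed_line, "SL")
```

Because SL and SH patches only touch the marked location, they constrain the model far more tightly than SF patches, which may rewrite anything inside the function.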

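Prompt

Patch Generation

The patch-generation prompt has the following core parts:

  • System role definition

    • “You are an automated program repair tool. Please do not use language features beyond java 1.4, such as foreach and generics <>.”
  • Problem description

    • Single Line
      • Description: "The following code contains a buggy line that has been removed."
      • Includes the masked code with the >>> [ INFILL ] <<< marker
      • Shows the original buggy line that was removed
    • Single Hunk
      • Description: "The following code contains a buggy hunk that has been removed."
      • Includes the masked code block with the >>> [ INFILL ] <<< marker
      • Shows the original buggy hunk that was removed
    • Single Function
      • Description: "The following code contains a bug"
      • Shows the complete buggy function directly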
  • Code context

    • bug.masked_code  // masked code (SL/SH)
      bug.code  // full code (SF)
      bug.buggy_lines  // original buggy line/hunk
  • Test cases

    • Test cases look like:
      bug.extract_test_code  // extracted code of the relevant test methods
  • Error message

    • The code fails with the following test error:
      bug.failing_tests  // error messages of the failing tests
  • Reasoning guidance

    • “Before you give the final answer, let’s think step by step. You need to explain where bug happens and how your answer can avoid it.”
  • Output format requirements

    • SL mode
      • "please provide the correct line at the infill location, only single line is allowed"
      • "your answer must be different from [original buggy line]"
      • "your answer should begin with ```java"
    • SH mode
      • "please provide the correct hunk at the infill location, only single hunk is allowed"
      • Format requirements similar to SL
    • SF mode
      • "please provide the correct function, starting with ```java"
  • Constraints

    • The answer must differ from the original buggy code
    • The answer must begin with a specific format (```java)
    • Only code within the designated range may be modified
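The parts listed above could be assembled roughly as follows. This is a sketch under assumptions, not the actual construct_gpt_policy_prompt: BugInfo and build_repair_prompt are hypothetical names, the ordering and separators are guesses, and the sentence introducing the removed code is invented for illustration. Only the quoted instruction strings come from the description above.

```python
from dataclasses import dataclass

FENCE = "`" * 3  # three backticks, built programmatically so this example nests cleanly

DESCRIPTIONS = {
    "SL": "The following code contains a buggy line that has been removed.",
    "SH": "The following code contains a buggy hunk that has been removed.",
    "SF": "The following code contains a bug",
}
ANSWER_RULES = {
    "SL": "please provide the correct line at the infill location, only single line is allowed",
    "SH": "please provide the correct hunk at the infill location, only single hunk is allowed",
    "SF": "please provide the correct function, starting with " + FENCE + "java",
}
REASONING = ("Before you give the final answer, let's think step by step. "
             "You need to explain where bug happens and how your answer can avoid it.")


@dataclass
class BugInfo:
    """Hypothetical subset of the Bug fields that the prompt uses."""
    code: str
    masked_code: str
    buggy_lines: str
    extract_test_code: str
    failing_tests: str


def build_repair_prompt(bug: BugInfo, mode: str) -> str:
    parts = [DESCRIPTIONS[mode]]
    if mode == "SF":
        parts.append(bug.code)  # whole buggy function
    else:
        parts.append(bug.masked_code)  # code with the infill marker
        parts.append("The removed code was:\n" + bug.buggy_lines)  # wording assumed
    parts.append("Test cases look like:\n" + bug.extract_test_code)
    parts.append("The code fails with the following test error:\n" + bug.failing_tests)
    parts.append(REASONING)
    parts.append(ANSWER_RULES[mode])
    if mode != "SF":
        parts.append("your answer must be different from " + bug.buggy_lines.strip())
        parts.append("your answer should begin with " + FENCE + "java")
    return "\n\n".join(parts)
```

Keeping the per-mode strings in dictionaries makes it easy to see that SL and SH differ only in wording, while SF swaps the masked code for the full function and drops the infill-specific rules.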