Skip to content

# 29.3 代码生成模块

## 29.3.1 代码生成概述

代码生成模块是编程 Agent 的核心能力之一,它能够根据自然语言描述生成高质量的代码。代码生成涉及需求理解、架构设计、代码实现等多个环节。

### 代码生成流程

```text
用户需求
  ↓
需求分析与理解
  ↓
架构设计
  ↓
代码实现
  ↓
代码验证
  ↓
优化与改进
  ↓
最终代码
```

## 29.3.2 需求分析

### 需求提取器

    ```python

    class RequirementExtractor:
        """Turn a free-form user request into a structured ``Requirement``.

        The request is embedded in an analysis prompt, sent to the LLM client,
        and the reply is parsed as a JSON object.

        NOTE(review): the prompt does not name the JSON keys that
        ``_parse_requirement`` reads, so a reply using other keys silently
        falls back to empty defaults.
        """

        def __init__(self, llm_client: LLMClient):
            """Keep the client used to run completions.

            Args:
                llm_client: Client whose ``complete(prompt)`` coroutine is
                    awaited by :meth:`extract`.
            """
            self.llm_client = llm_client

        async def extract(self, user_request: str) -> Requirement:
            """Ask the LLM to analyse ``user_request`` and parse its reply.

            Args:
                user_request: Natural-language description of what to build.

            Returns:
                The ``Requirement`` built from the JSON reply.

            Raises:
                ValueError: If the reply is not a JSON object.
            """
            prompt = f"""
            分析用户需求,提取关键信息:

            用户需求:{user_request}

            请提取以下信息:
            1. 功能需求(需要实现什么功能)
            2. 技术栈(使用的编程语言、框架等)
            3. 约束条件(性能、安全、兼容性等)
            4. 输入输出(预期的输入和输出)
            5. 特殊要求(代码风格、注释要求等)

            以 JSON 格式返回结果。
            """

            response = await self.llm_client.complete(prompt)
            return self._parse_requirement(response)

        def _parse_requirement(self, response: str) -> Requirement:
            """Build a ``Requirement`` from the LLM's JSON reply.

            Args:
                response: Raw completion text, expected to be a JSON object.

            Returns:
                A ``Requirement`` whose fields come from the object's keys,
                with empty containers substituted for missing keys.

            Raises:
                ValueError: If ``response`` is not valid JSON, or is valid JSON
                    that is not an object.
            """
            try:
                data = json.loads(response)
            except json.JSONDecodeError as err:
                # Chain the decode error so the offending position stays visible.
                raise ValueError("Invalid requirement format") from err

            # A JSON list/number/null would otherwise fail on ``.get`` with an
            # AttributeError instead of the ValueError callers expect.
            if not isinstance(data, dict):
                raise ValueError("Invalid requirement format")

            return Requirement(
                functional_requirements=data.get('functional_requirements', []),
                tech_stack=data.get('tech_stack', {}),
                constraints=data.get('constraints', {}),
                inputs=data.get('inputs', []),
                outputs=data.get('outputs', []),
                special_requirements=data.get('special_requirements', {})
            )

    ```

### 需求验证器

```python
    class RequirementValidator:
        """Check that an extracted requirement is complete enough to act on."""

        def validate(self, requirement: Requirement) -> ValidationResult:
            """Collect the problems found in ``requirement``.

            Args:
                requirement: Object exposing ``functional_requirements``,
                    ``tech_stack`` and ``constraints``.

            Returns:
                A ``ValidationResult`` built with ``valid`` (True when no
                problem was found) and ``issues`` (the list of messages).
            """
            issues = []

            # At least one functional requirement must be present.
            if not requirement.functional_requirements:
                issues.append("No functional requirements specified")

            # A tech stack must be present.
            if not requirement.tech_stack:
                issues.append("No tech stack specified")

            # A performance constraint, when given, must be a dict with 'max_time'.
            if 'performance' in requirement.constraints:
                perf = requirement.constraints['performance']
                if not isinstance(perf, dict) or 'max_time' not in perf:
                    issues.append("Invalid performance constraint")

            return ValidationResult(
                valid=not issues,
                issues=issues
            )

```

## 29.3.3 架构设计

### 架构设计器

    ```python

    class ArchitectureDesigner:
        """Ask the LLM for a software architecture that fits a requirement.

        NOTE(review): ``design_patterns`` is loaded at construction time but is
        not referenced by any other method of this class.
        """

        def __init__(self, llm_client: LLMClient):
            """Keep the LLM client and load the built-in pattern catalogue.

            Args:
                llm_client: Client whose ``complete(prompt)`` coroutine is
                    awaited by :meth:`design`.
            """
            self.llm_client = llm_client
            self.design_patterns = self._load_design_patterns()

        async def design(self, requirement: Requirement) -> Architecture:
            """Prompt the LLM for an architecture and parse its JSON reply.

            Args:
                requirement: Source of the functional requirements, tech stack
                    and constraints interpolated into the prompt.

            Returns:
                The ``Architecture`` built from the JSON reply.

            Raises:
                ValueError: If the reply is not a JSON object.
            """
            prompt = f"""
            根据需求设计软件架构:

            功能需求:{requirement.functional_requirements}
            技术栈:{requirement.tech_stack}
            约束条件:{requirement.constraints}

            请设计:
            1. 系统架构(模块划分、层次结构)
            2. 类设计(类、接口、继承关系)
            3. 数据结构(数据模型、存储方案)
            4. 接口设计(API、函数签名)
            5. 设计模式(适用的设计模式)

            以 JSON 格式返回架构设计。
            """

            response = await self.llm_client.complete(prompt)
            return self._parse_architecture(response)

        def _parse_architecture(self, response: str) -> Architecture:
            """Build an ``Architecture`` from the LLM's JSON reply.

            Args:
                response: Raw completion text, expected to be a JSON object.

            Returns:
                An ``Architecture`` whose fields come from the object's keys,
                with empty containers substituted for missing keys.

            Raises:
                ValueError: If ``response`` is not valid JSON, or is valid JSON
                    that is not an object.
            """
            try:
                data = json.loads(response)
            except json.JSONDecodeError as err:
                # Chain the decode error so the offending position stays visible.
                raise ValueError("Invalid architecture format") from err

            # A JSON list/number/null would otherwise fail on ``.get`` with an
            # AttributeError instead of the ValueError callers expect.
            if not isinstance(data, dict):
                raise ValueError("Invalid architecture format")

            return Architecture(
                system_architecture=data.get('system_architecture', {}),
                class_design=data.get('class_design', []),
                data_structures=data.get('data_structures', []),
                interfaces=data.get('interfaces', []),
                design_patterns=data.get('design_patterns', [])
            )

        def _load_design_patterns(self) -> Dict[str, DesignPattern]:
            """Return the built-in pattern catalogue keyed by lowercase name.

            NOTE(review): the third keyword argument is spelled ``适用场景``
            ("applicable scenario"); confirm that this matches the field name
            declared on ``DesignPattern``.
            """
            return {
                'singleton': DesignPattern(
                    name='Singleton',
                    description='确保一个类只有一个实例',
                    适用场景='需要全局唯一访问点'
                ),
                'factory': DesignPattern(
                    name='Factory',
                    description='创建对象的接口',
                    适用场景='需要灵活创建对象'
                ),
                'observer': DesignPattern(
                    name='Observer',
                    description='定义对象间的一对多依赖',
                    适用场景='需要事件通知机制'
                )
            }

    ```

### 架构评估器

    class ArchitectureEvaluator:
        """Score an architecture on four heuristic axes and suggest fixes."""

        def evaluate(self, architecture: Architecture,
                     requirement: Requirement) -> EvaluationResult:
            """Score ``architecture`` and average the per-axis scores.

            Args:
                architecture: Design to score.
                requirement: Only its constraints are consulted, by the
                    performance score.

            Returns:
                An ``EvaluationResult`` carrying the unweighted mean as
                ``total_score``, the per-axis ``scores`` and the
                ``recommendations`` derived from them.
            """
            scores = {
                'modularity': self._evaluate_modularity(architecture),
                'extensibility': self._evaluate_extensibility(architecture),
                'performance': self._evaluate_performance(
                    architecture,
                    requirement
                ),
                'maintainability': self._evaluate_maintainability(architecture),
            }

            # Unweighted mean of the four axes.
            total_score = sum(scores.values()) / len(scores)

            return EvaluationResult(
                total_score=total_score,
                scores=scores,
                recommendations=self._generate_recommendations(scores)
            )

        def _evaluate_modularity(self, architecture: Architecture) -> float:
            """Score modularity from the number of declared modules.

            Returns:
                0.0 when no modules are declared, otherwise ``len(modules) / 10``
                capped at 1.0.
            """
            modules = architecture.system_architecture.get('modules', [])
            if not modules:
                return 0.0

            # More modules means a higher score; ten or more reach the cap.
            return min(len(modules) / 10.0, 1.0)

        def _evaluate_extensibility(self, architecture: Architecture) -> float:
            """Score extensibility from the number of design patterns used.

            Returns:
                0.5 when no patterns are listed, otherwise 0.5 plus
                ``len(patterns) / 5`` with the bonus capped at 0.5.
            """
            patterns = architecture.design_patterns
            if not patterns:
                return 0.5

            # Each pattern adds 0.1 on top of the 0.5 baseline, up to 1.0.
            return 0.5 + min(len(patterns) / 5.0, 0.5)

        def _evaluate_performance(self, architecture: Architecture,
                                  requirement: Requirement) -> float:
            """Score how well the architecture addresses performance constraints.

            Returns:
                0.8 when the requirement has no performance constraint;
                otherwise 0.8 plus 0.1 for each of the ``'caching'`` and
                ``'concurrency'`` keys present in the system architecture,
                capped at 1.0.
            """
            constraints = requirement.constraints.get('performance', {})
            if not constraints:
                return 0.8  # default score when nothing is required

            score = 0.8  # baseline

            # Bonus when a caching strategy is declared.
            if 'caching' in architecture.system_architecture:
                score += 0.1

            # Bonus when concurrency handling is declared.
            if 'concurrency' in architecture.system_architecture:
                score += 0.1

            return min(score, 1.0)

        def _evaluate_maintainability(self, architecture: Architecture) -> float:
            """Score maintainability from the average method count per class.

            Returns:
                0.5 when no classes are designed; 1.0 for an average of 5 to 15
                methods, 0.8 for fewer than 5, and 0.6 for more than 15.
            """
            classes = architecture.class_design
            if not classes:
                return 0.5

            avg_methods = sum(
                len(c.get('methods', [])) for c in classes
            ) / len(classes)

            # A moderate number of methods per class is treated as ideal.
            if 5 <= avg_methods <= 15:
                return 1.0
            if avg_methods < 5:
                return 0.8
            return 0.6

        def _generate_recommendations(self,
                                      scores: Dict[str, float]) -> List[str]:
            """List one recommendation for each low-scoring axis.

            Modularity, extensibility and maintainability are checked against a
            0.7 threshold; the performance score is not consulted.
            """
            recommendations = []

            if scores['modularity'] < 0.7:
                recommendations.append(
                    "建议增加模块划分,提高模块化程度"
                )

            if scores['extensibility'] < 0.7:
                recommendations.append(
                    "建议使用更多设计模式,提高可扩展性"
                )

            if scores['maintainability'] < 0.7:
                recommendations.append(
                    "建议简化类设计,降低复杂度"
                )

            return recommendations

## 29.3.4 代码实现

### 代码生成器

    ```python

    class CodeGenerator:
        """Generate source code for an architecture by prompting an LLM.

        One completion is requested per class, one per interface and one for
        the main program; the pieces are then joined into a single listing.
        """

        def __init__(self, llm_client: LLMClient):
            """Keep the LLM client and load the code templates.

            Args:
                llm_client: Client whose ``complete(prompt)`` coroutine is
                    awaited for every generated piece.
            """
            self.llm_client = llm_client
            self.code_templates = self._load_code_templates()

        def _load_code_templates(self) -> Dict[str, str]:
            """Return the code templates available to the generator.

            No templates are bundled, so this returns an empty mapping.
            Subclasses may override it to supply their own. Without this hook
            the constructor raised ``AttributeError``.
            """
            return {}

        async def generate(self, architecture: Architecture,
                           requirement: Requirement) -> GeneratedCode:
            """Generate every piece of code for ``architecture``.

            Completions are requested one at a time, in order: classes,
            interfaces, then the main program.

            Args:
                architecture: Provides ``class_design`` and ``interfaces``.
                requirement: Provides the target language and the functional
                    requirements used in the prompts.

            Returns:
                A ``GeneratedCode`` carrying the combined listing and the
                individual pieces.
            """
            # One completion per designed class.
            class_codes = []
            for class_design in architecture.class_design:
                class_codes.append(
                    await self._generate_class_code(class_design, requirement)
                )

            # One completion per designed interface.
            interface_codes = []
            for interface in architecture.interfaces:
                interface_codes.append(
                    await self._generate_interface_code(interface, requirement)
                )

            # Entry-point code that ties the pieces together.
            main_code = await self._generate_main_code(
                architecture,
                requirement
            )

            full_code = self._combine_codes(
                class_codes,
                interface_codes,
                main_code
            )

            return GeneratedCode(
                full_code=full_code,
                class_codes=class_codes,
                interface_codes=interface_codes,
                main_code=main_code
            )

        async def _generate_class_code(self, class_design: Dict,
                                       requirement: Requirement) -> str:
            """Request the implementation of one class from the LLM.

            Args:
                class_design: Mapping read for ``name``, ``methods``,
                    ``attributes`` and ``parent``.
                requirement: Its tech stack supplies the language, which
                    defaults to ``'Python'``.

            Returns:
                The raw completion text.
            """
            prompt = f"""
            根据类设计生成代码:

            类名:{class_design.get('name')}
            方法:{class_design.get('methods', [])}
            属性:{class_design.get('attributes', [])}
            父类:{class_design.get('parent', 'None')}
            编程语言:{requirement.tech_stack.get('language', 'Python')}

            请生成完整的类代码,包括:
            1. 类定义
            2. 所有方法的实现
            3. 必要的注释
            4. 错误处理
            """

            return await self.llm_client.complete(prompt)

        async def _generate_interface_code(self, interface: Dict,
                                           requirement: Requirement) -> str:
            """Request the code of one interface from the LLM.

            Args:
                interface: Mapping read for ``name`` and ``methods``.
                requirement: Its tech stack supplies the language, which
                    defaults to ``'Python'``.

            Returns:
                The raw completion text.
            """
            prompt = f"""
            根据接口设计生成代码:

            接口名:{interface.get('name')}
            方法:{interface.get('methods', [])}
            编程语言:{requirement.tech_stack.get('language', 'Python')}

            请生成完整的接口代码。
            """

            return await self.llm_client.complete(prompt)

        async def _generate_main_code(self, architecture: Architecture,
                                      requirement: Requirement) -> str:
            """Request the main-program code from the LLM.

            The prompt lists the functional requirements and the names of all
            designed classes and interfaces.

            Returns:
                The raw completion text.
            """
            prompt = f"""
            根据架构和需求生成主程序代码:

            功能需求:{requirement.functional_requirements}
            类:{[c.get('name') for c in architecture.class_design]}
            接口:{[i.get('name') for i in architecture.interfaces]}
            编程语言:{requirement.tech_stack.get('language', 'Python')}

            请生成主程序代码,包括:
            1. 初始化代码
            2. 主要业务逻辑
            3. 示例用法
            """

            return await self.llm_client.complete(prompt)

        def _combine_codes(self, class_codes: List[str],
                           interface_codes: List[str],
                           main_code: str) -> str:
            """Join the generated pieces into one listing.

            The layout is a header line, then interfaces, then classes, then
            the main program. The interface and class sections are omitted
            when empty, and each snippet in them is followed by a blank line.

            Returns:
                The newline-joined listing.
            """
            # Header line followed by a blank separator.
            combined = ["# Generated Code", ""]

            if interface_codes:
                combined.append("# Interfaces")
                for snippet in interface_codes:
                    combined.extend((snippet, ""))

            if class_codes:
                combined.append("# Classes")
                for snippet in class_codes:
                    combined.extend((snippet, ""))

            combined.extend(("# Main Program", main_code))

            return "\n".join(combined)

    ```

### 代码优化器

    class CodeOptimizer:
        """Improve generated code through an analyse / suggest / apply loop.

        NOTE(review): ``_parse_issues`` and ``_parse_suggestions`` are called
        below but are not defined in this class as shown; they must be
        supplied (for example by a subclass) before ``optimize`` can run.
        """

        def __init__(self, llm_client: LLMClient):
            """Keep the client whose ``complete(prompt)`` coroutine is awaited."""
            self.llm_client = llm_client

        async def optimize(self, code: str,
                           requirement: Requirement) -> OptimizedCode:
            """Analyse ``code``, derive suggestions and apply them.

            Args:
                code: Source text to improve.
                requirement: Its constraints are passed to the suggestion
                    prompt.

            Returns:
                An ``OptimizedCode`` carrying the original code, the optimized
                code, the issues found and the suggestions produced.
            """
            # Step 1: find problems in the code.
            issues = await self._analyze_issues(code)

            # Step 2: turn the problems into concrete suggestions.
            suggestions = await self._generate_suggestions(
                code,
                issues,
                requirement
            )

            # Step 3: apply the applicable suggestions one by one.
            optimized_code = await self._apply_optimizations(
                code,
                suggestions
            )

            return OptimizedCode(
                original_code=code,
                optimized_code=optimized_code,
                issues=issues,
                suggestions=suggestions
            )

        async def _analyze_issues(self, code: str) -> List[CodeIssue]:
            """Ask the LLM for a JSON list of problems in ``code``.

            Returns:
                Whatever ``_parse_issues`` builds from the completion.
            """
            # Prompt text is kept verbatim, hence its shallow indentation.
            prompt = f"""
    分析以下代码的问题:
    {code}
    请识别:
    1. 性能问题
    2. 安全问题
    3. 代码风格问题
    4. 潜在的 bug
    5. 可维护性问题
    以 JSON 格式返回问题列表。
    """
            response = await self.llm_client.complete(prompt)
            return self._parse_issues(response)

        async def _generate_suggestions(self, code: str,
                                        issues: List[CodeIssue],
                                        requirement: Requirement) -> List[Suggestion]:
            """Ask the LLM for optimization suggestions addressing ``issues``.

            Returns:
                Whatever ``_parse_suggestions`` builds from the completion.
            """
            # Prompt text is kept verbatim, hence its shallow indentation.
            prompt = f"""
    基于代码问题生成优化建议:
    代码:{code}
    问题:{issues}
    约束条件:{requirement.constraints}
    请生成具体的优化建议,包括:
    1. 问题描述
    2. 优化方案
    3. 预期效果
    以 JSON 格式返回建议列表。
    """
            response = await self.llm_client.complete(prompt)
            return self._parse_suggestions(response)

        async def _apply_optimizations(self, code: str,
                                       suggestions: List[Suggestion]) -> str:
            """Apply every applicable suggestion in order.

            Each suggestion is applied to the output of the previous one.
            Suggestions whose ``applicable`` attribute is falsy are skipped.

            Returns:
                The rewritten code, or ``code`` unchanged when nothing applies.
            """
            optimized_code = code

            for suggestion in suggestions:
                if suggestion.applicable:
                    optimized_code = await self._apply_suggestion(
                        optimized_code,
                        suggestion
                    )

            return optimized_code

        async def _apply_suggestion(self, code: str,
                                    suggestion: Suggestion) -> str:
            """Ask the LLM to rewrite ``code`` according to one suggestion.

            Returns:
                The raw completion text, used as the new code.
            """
            # Prompt text is kept verbatim, hence its shallow indentation.
            prompt = f"""
    应用以下优化建议到代码:
    原始代码:{code}
    优化建议:{suggestion.description}
    优化方案:{suggestion.solution}
    请返回优化后的代码。
    """
            return await self.llm_client.complete(prompt)

## 29.3.5 代码验证

### 代码验证器

    ```python

    class CodeValidator:
        """Run syntax, type, logic and performance checks over generated code.

        The logic and performance checks are plain substring heuristics, not
        static analysis.

        NOTE(review): ``validate`` builds ``ValidationResult`` with ``passed=``
        and ``results=``, whereas ``RequirementValidator`` builds it with
        ``valid=``; confirm which keywords the constructor accepts.
        """

        def __init__(self, tool_manager: ToolManager):
            """Keep the tool manager queried for the ``'type_checker'`` tool."""
            self.tool_manager = tool_manager

        async def validate(self, code: str,
                           requirement: Requirement) -> ValidationResult:
            """Run all four checks on ``code`` and aggregate their outcome.

            Args:
                code: Source text to validate.
                requirement: Its tech stack selects the language for the
                    syntax check.

            Returns:
                A ``ValidationResult`` whose ``passed`` is True only when every
                check passed, with the individual results and the messages of
                the failed checks.
            """
            # Checks run one at a time, in order: syntax, types, logic,
            # performance.
            results = [
                await self._check_syntax(code, requirement),
                await self._check_types(code, requirement),
                await self._check_logic(code, requirement),
                await self._check_performance(code, requirement),
            ]

            all_passed = all(r.passed for r in results)

            return ValidationResult(
                passed=all_passed,
                results=results,
                issues=self._collect_issues(results)
            )

        async def _check_syntax(self, code: str,
                                requirement: Requirement) -> CheckResult:
            """Check the syntax of ``code`` for the requirement's language.

            Only Python is supported. Any other language yields a passing
            result whose message says the check is not implemented.
            """
            language = requirement.tech_stack.get('language', 'python')

            try:
                # Compare case-insensitively: the generator prompts default to
                # 'Python', which an exact match against 'python' would skip.
                if str(language).strip().lower() == 'python':
                    result = await self._check_python_syntax(code)
                else:
                    result = CheckResult(
                        check_type='syntax',
                        passed=True,
                        message=f"Syntax check for {language} not implemented"
                    )

                return result

            except Exception as e:
                # compile() can raise more than SyntaxError (e.g. ValueError
                # for null bytes); report any failure as a failed check.
                return CheckResult(
                    check_type='syntax',
                    passed=False,
                    message=f"Syntax error: {str(e)}"
                )

        async def _check_python_syntax(self, code: str) -> CheckResult:
            """Compile ``code`` without executing it and report syntax errors."""
            try:
                compile(code, '<string>', 'exec')
            except SyntaxError as e:
                return CheckResult(
                    check_type='syntax',
                    passed=False,
                    message=f"Syntax error at line {e.lineno}: {e.msg}"
                )

            return CheckResult(
                check_type='syntax',
                passed=True,
                message="Syntax is valid"
            )

        async def _check_types(self, code: str,
                               requirement: Requirement) -> CheckResult:
            """Run the ``'type_checker'`` tool over ``code`` when available.

            A missing tool counts as a pass. A tool failure or an exception
            raised by the tool counts as a failed check.
            """
            tool = self.tool_manager.get_tool('type_checker')

            if not tool:
                return CheckResult(
                    check_type='type',
                    passed=True,
                    message="Type checker not available"
                )

            try:
                result = await tool.execute({'code': code})
            except Exception as e:
                return CheckResult(
                    check_type='type',
                    passed=False,
                    message=f"Type check error: {str(e)}"
                )

            if result.success:
                return CheckResult(
                    check_type='type',
                    passed=True,
                    message="Type check passed"
                )

            return CheckResult(
                check_type='type',
                passed=False,
                message=f"Type check failed: {result.error}"
            )

        async def _check_logic(self, code: str,
                               requirement: Requirement) -> CheckResult:
            """Flag likely logic problems using substring heuristics."""
            issues = []

            # None handling: 'None' appears but no 'if' appears anywhere.
            if 'None' in code and 'if' not in code:
                issues.append("Potential None reference without check")

            # Resource leak: a file is opened but never closed. Files opened
            # through ``with open(...)`` are closed by the context manager, so
            # they are not reported.
            opens_file = 'open(' in code
            closes_file = 'close(' in code or 'with open(' in code
            if opens_file and not closes_file:
                issues.append("Potential resource leak (file not closed)")

            if issues:
                return CheckResult(
                    check_type='logic',
                    passed=False,
                    message=f"Logic issues: {', '.join(issues)}"
                )

            return CheckResult(
                check_type='logic',
                passed=True,
                message="Logic check passed"
            )

        async def _check_performance(self, code: str,
                                     requirement: Requirement) -> CheckResult:
            """Flag likely performance problems using substring heuristics.

            NOTE(review): the loop check counts every ``'for '`` occurrence in
            the text; it does not measure nesting depth.
            """
            issues = []

            # More than two 'for ' occurrences anywhere in the text.
            if code.count('for ') > 2:
                issues.append("Deep nested loops may cause performance issues")

            # Both 'list(' and 'range(' appear somewhere in the text.
            if 'list(' in code and 'range(' in code:
                issues.append("Consider using generator expressions for large ranges")

            if issues:
                return CheckResult(
                    check_type='performance',
                    passed=False,
                    message=f"Performance issues: {', '.join(issues)}"
                )

            return CheckResult(
                check_type='performance',
                passed=True,
                message="Performance check passed"
            )

        def _collect_issues(self,
                            results: List[CheckResult]) -> List[str]:
            """Return the messages of all failed checks, in order."""
            return [result.message for result in results if not result.passed]

    ```

通过实现这些组件,我们可以构建一个完整的代码生成模块,实现从需求分析到代码验证的全流程自动化。

基于 MIT 许可发布 | 永久导航