Skip to content

# 32.2 企业安全配置

python
## 32.2.1 身份验证与授权

### API 密钥管理

#### 集中式密钥管理

class APIKeyManager:
    """API key manager that reads secrets from a HashiCorp Vault server.

    Connection settings come from the ``VAULT_ADDR`` and ``VAULT_TOKEN``
    environment variables. Fetched keys are cached in memory for
    ``cache_ttl`` seconds so repeated lookups skip the network round-trip.
    """

    def __init__(self):
        self.vault_url = os.getenv('VAULT_ADDR')
        self.vault_token = os.getenv('VAULT_TOKEN')
        # key_name -> {'key': <secret>, 'timestamp': <time.monotonic() at fetch>}
        self.key_cache = {}
        self.cache_ttl = 3600  # 1 hour
        # Seconds to wait for Vault; without a timeout a stalled connection
        # would block the caller indefinitely.
        self.request_timeout = 10

    def get_key(self, key_name: str) -> str:
        """Return the API key named ``key_name``.

        A cached value younger than ``cache_ttl`` seconds is returned
        directly; otherwise the key is fetched from Vault and the cache entry
        is refreshed.

        Raises:
            Exception: propagated from ``_fetch_from_vault`` on a cache miss.
        """
        cached = self.key_cache.get(key_name)
        # monotonic() is unaffected by wall-clock changes (NTP sync, manual
        # adjustment), which could otherwise stretch or cut short the TTL.
        if cached is not None and time.monotonic() - cached['timestamp'] < self.cache_ttl:
            return cached['key']

        key = self._fetch_from_vault(key_name)
        self.key_cache[key_name] = {'key': key, 'timestamp': time.monotonic()}
        return key

    def _fetch_from_vault(self, key_name: str) -> str:
        """Read the ``value`` field of secret ``key_name`` from Vault.

        Calls ``<VAULT_ADDR>/v1/secret/data/<key_name>`` and reads the value
        nested under ``data.data`` in the JSON reply (the KV v2 read layout).

        Raises:
            Exception: if VAULT_ADDR is unset or Vault returns a non-200 status.
            Any error from the HTTP request or from parsing the reply (e.g. a
            KeyError for a missing field) is logged and re-raised unchanged.
        """
        try:
            if not self.vault_url:
                # Fail clearly instead of requesting "None/v1/...".
                raise Exception("VAULT_ADDR is not set")
            response = requests.get(
                f"{self.vault_url}/v1/secret/data/{key_name}",
                headers={'X-Vault-Token': self.vault_token},
                timeout=self.request_timeout,
            )
            if response.status_code != 200:
                raise Exception(f"Failed to fetch key: {response.status_code}")
            return response.json()['data']['data']['value']
        except Exception as e:
            logger.error(f"Error fetching key from vault: {e}")
            raise

#### 密钥轮换策略

    bash


    python

    class KeyRotationManager:
        """Tracks per-key rotation schedules and performs key rotations."""

        def __init__(self):
            # key_name -> {'interval_days', 'next_rotation', 'last_rotation'}
            self.rotation_schedule = {}
            # One record per rotation; stores key hashes, never the raw keys.
            self.rotation_history = []

        def schedule_rotation(self, key_name: str,
                              interval_days: int = 90):
            """Schedule ``key_name`` for rotation every ``interval_days`` days.

            Re-scheduling an existing key replaces its schedule, including
            resetting ``last_rotation`` to None.
            """
            next_rotation = datetime.now() + timedelta(days=interval_days)
            self.rotation_schedule[key_name] = {
                'interval_days': interval_days,
                'next_rotation': next_rotation,
                'last_rotation': None
            }

            logger.info(f"Scheduled rotation for {key_name} in {interval_days} days")

        def check_rotations(self) -> List[str]:
            """Return the names of scheduled keys whose rotation time has passed."""
            now = datetime.now()
            return [
                key_name
                for key_name, schedule in self.rotation_schedule.items()
                if schedule['next_rotation'] <= now
            ]

        def rotate_key(self, key_name: str) -> "RotationResult":
            """Rotate ``key_name`` and advance its schedule by one interval.

            Does not raise: any failure is reported through
            ``result.success = False`` and ``result.error``.
            """
            result = RotationResult(key_name=key_name)

            schedule = self.rotation_schedule.get(key_name)
            if schedule is None:
                # Check before touching the live configuration so an
                # unscheduled key is never left half-rotated.
                result.success = False
                result.error = f"No rotation scheduled for {key_name}"
                return result

            try:
                # Read the current key BEFORE overwriting it; otherwise the
                # "old" hash would be computed from the new key.
                # NOTE(review): assumes _get_old_key returns the key currently
                # in the configuration — confirm against its implementation.
                old_key = self._get_old_key(key_name)
                new_key = self._generate_new_key()
                self._update_key_configuration(key_name, new_key)

                # One clock read so history and schedule agree exactly.
                rotated_at = datetime.now()
                self.rotation_history.append({
                    'key_name': key_name,
                    'rotated_at': rotated_at,
                    'old_key_hash': self._hash_key(old_key),
                    'new_key_hash': self._hash_key(new_key)
                })

                schedule['last_rotation'] = rotated_at
                schedule['next_rotation'] = rotated_at + timedelta(
                    days=schedule['interval_days']
                )

                result.success = True
                result.new_key = new_key

            except Exception as e:
                result.success = False
                result.error = str(e)

            return result

### SSO 集成

#### OAuth 2.0 配置

    class SSOAuthenticator:
        """OAuth 2.0 authorization-code flow client for single sign-on."""

        def __init__(self, config: Dict):
            """Load client settings from ``config``.

            Required keys: client_id, client_secret, redirect_uri, auth_url,
            token_url. Optional: scopes (default ['openid', 'profile']) and
            request_timeout (seconds for token-endpoint calls, default 10).
            """
            self.client_id = config['client_id']
            self.client_secret = config['client_secret']
            self.redirect_uri = config['redirect_uri']
            self.auth_url = config['auth_url']
            self.token_url = config['token_url']
            # Copied so later changes to the caller's list can't alter our scopes.
            self.scopes = list(config.get('scopes', ['openid', 'profile']))
            # Without a timeout, requests waits indefinitely on an unresponsive IdP.
            self.request_timeout = config.get('request_timeout', 10)

        def get_auth_url(self, state: str = None) -> str:
            """Build the authorization-endpoint URL to redirect the user to.

            ``state`` is the anti-CSRF value the provider echoes back to
            ``redirect_uri``. If omitted, a random one is generated but NOT
            returned, so the callback cannot verify it; callers should pass
            their own session-bound state.
            """
            params = {
                'response_type': 'code',
                'client_id': self.client_id,
                'redirect_uri': self.redirect_uri,
                'scope': ' '.join(self.scopes),
                'state': state or self._generate_state()
            }
            return f"{self.auth_url}?{urllib.parse.urlencode(params)}"

        def exchange_code_for_token(self,
                                    auth_code: str) -> "TokenResponse":
            """Exchange an authorization code for an access token.

            Raises:
                Exception: if the token endpoint does not answer with 200.
            """
            return self._request_token(
                {
                    'grant_type': 'authorization_code',
                    'code': auth_code,
                    'client_id': self.client_id,
                    'client_secret': self.client_secret,
                    'redirect_uri': self.redirect_uri
                },
                "Token exchange failed",
            )

        def refresh_access_token(self,
                                 refresh_token: str) -> "TokenResponse":
            """Obtain a new access token using ``refresh_token``.

            If the provider does not issue a new refresh token, the one passed
            in is carried over into the result.

            Raises:
                Exception: if the token endpoint does not answer with 200.
            """
            return self._request_token(
                {
                    'grant_type': 'refresh_token',
                    'refresh_token': refresh_token,
                    'client_id': self.client_id,
                    'client_secret': self.client_secret
                },
                "Token refresh failed",
                fallback_refresh_token=refresh_token,
            )

        def _request_token(self, data: Dict, failure_message: str,
                           fallback_refresh_token: str = None) -> "TokenResponse":
            """POST a form-encoded grant to the token endpoint and wrap the reply.

            Missing optional fields default to: refresh_token ->
            ``fallback_refresh_token``, expires_in -> 3600, token_type -> 'Bearer'.

            Raises:
                Exception: ``"<failure_message>: <status>"`` on a non-200 reply.
            """
            response = requests.post(self.token_url, data=data,
                                     timeout=self.request_timeout)
            if response.status_code != 200:
                raise Exception(f"{failure_message}: {response.status_code}")
            token_data = response.json()
            return TokenResponse(
                access_token=token_data['access_token'],
                refresh_token=token_data.get('refresh_token', fallback_refresh_token),
                expires_in=token_data.get('expires_in', 3600),
                token_type=token_data.get('token_type', 'Bearer')
            )

        def _generate_state(self) -> str:
            """Return a random URL-safe state value built from 16 random bytes."""
            return secrets.token_urlsafe(16)

### 多因素认证 (MFA)

    bash


    python

    class MFAAuthenticator:
        """Verifies second-factor codes from a TOTP app, SMS or e-mail."""

        def __init__(self):
            # Method name -> bound verifier taking (code, user_id) -> bool
            self.mfa_methods = {
                'totp': self._verify_totp,
                'sms': self._verify_sms,
                'email': self._verify_email
            }

        def verify_mfa(self, method: str,
                       code: str,
                       user_id: str) -> bool:
            """Check ``code`` for ``user_id`` with the named MFA ``method``.

            Raises:
                ValueError: if ``method`` is not a key of ``mfa_methods``.
            """
            verifier = self.mfa_methods.get(method)

            if not verifier:
                raise ValueError(f"Unsupported MFA method: {method}")

            return verifier(code, user_id)

        def _verify_totp(self, code: str, user_id: str) -> bool:
            """Verify a TOTP code against the user's shared secret."""
            totp = pyotp.TOTP(self._get_totp_secret(user_id))
            # valid_window=1 also accepts the previous and next time step's
            # code, tolerating small clock skew between client and server.
            return totp.verify(code, valid_window=1)

        def _verify_sms(self, code: str, user_id: str) -> bool:
            """Verify a code previously sent to the user by SMS.

            NOTE(review): the code is not invalidated here after a successful
            check; confirm replay protection exists elsewhere.
            """
            stored_code = self._get_stored_sms_code(user_id)
            return self._codes_match(stored_code, code) and not self._is_code_expired(user_id)

        def _verify_email(self, code: str, user_id: str) -> bool:
            """Verify a code previously sent to the user by e-mail.

            NOTE(review): the code is not invalidated here after a successful
            check; confirm replay protection exists elsewhere.
            """
            stored_code = self._get_stored_email_code(user_id)
            return self._codes_match(stored_code, code) and not self._is_code_expired(user_id)

        @staticmethod
        def _codes_match(stored_code, code) -> bool:
            """Compare a submitted code with the stored one in constant time.

            Returns False when no code is stored (``stored_code is None``).
            """
            import hmac

            if stored_code is None:
                return False
            if isinstance(stored_code, str) and isinstance(code, str):
                # compare_digest's run time does not depend on where the inputs
                # first differ, so response timing can't leak the code digit
                # by digit.
                return hmac.compare_digest(stored_code.encode('utf-8'),
                                           code.encode('utf-8'))
            return stored_code == code

## 32.2.2 权限控制

### 基于角色的访问控制 (RBAC)

    class RBACManager:
        """Role-based access control: roles map to permissions, users to roles."""

        def __init__(self):
            # role_name -> list of permission strings
            self.roles = {}
            # Not read by any method of this class; kept for external callers.
            self.permissions = {}
            # user_id -> list of assigned role names (no duplicates)
            self.user_roles = {}

        def define_role(self, role_name: str,
                        permissions: List[str]):
            """Define (or redefine) ``role_name`` with the given permissions.

            The permissions are copied, so later changes to the caller's list
            (for example a shared class-level role table) cannot silently
            change the role.
            """
            self.roles[role_name] = list(permissions)
            logger.info(f"Role {role_name} defined with {len(self.roles[role_name])} permissions")

        def assign_role(self, user_id: str, role_name: str):
            """Give ``user_id`` the role ``role_name``; assigning twice is a no-op.

            Raises:
                ValueError: if ``role_name`` has not been defined.
            """
            if role_name not in self.roles:
                raise ValueError(f"Role {role_name} not defined")

            assigned = self.user_roles.setdefault(user_id, [])
            if role_name not in assigned:
                assigned.append(role_name)
            logger.info(f"Role {role_name} assigned to user {user_id}")

        def check_permission(self, user_id: str,
                             permission: str) -> bool:
            """Return True if any role assigned to ``user_id`` grants ``permission``."""
            return any(
                permission in self.roles.get(role, [])
                for role in self.user_roles.get(user_id, [])
            )

        def get_user_permissions(self, user_id: str) -> List[str]:
            """Return the union of permissions across all of ``user_id``'s roles.

            The result is sorted so it is deterministic; a set has no stable order.
            """
            granted = set()
            for role in self.user_roles.get(user_id, []):
                granted.update(self.roles.get(role, []))
            return sorted(granted)

### 权限策略定义

    bash


    python

    class PermissionPolicy:
        """Built-in permission catalogue and the default roles that grant them."""

        # Permission id -> human-readable description
        PERMISSIONS = dict([
            ('code:generate', 'Generate code'),
            ('code:read', 'Read code'),
            ('code:write', 'Write code'),
            ('code:delete', 'Delete code'),
            ('file:read', 'Read files'),
            ('file:write', 'Write files'),
            ('file:delete', 'Delete files'),
            ('tool:execute', 'Execute tools'),
            ('config:manage', 'Manage configuration'),
            ('user:manage', 'Manage users'),
        ])

        # Role name -> granted permission ids. Each role's list contains
        # every permission of the role listed before it.
        ROLES = dict(
            viewer=['code:read', 'file:read'],
            developer=[
                'code:read', 'code:write', 'code:generate',
                'file:read', 'file:write',
                'tool:execute',
            ],
            senior_developer=[
                'code:read', 'code:write', 'code:generate', 'code:delete',
                'file:read', 'file:write', 'file:delete',
                'tool:execute',
            ],
            admin=[
                'code:read', 'code:write', 'code:generate', 'code:delete',
                'file:read', 'file:write', 'file:delete',
                'tool:execute',
                'config:manage', 'user:manage',
            ],
        )

### 权限检查中间件

    class PermissionMiddleware:
        """Enforces RBAC permissions on function calls through a decorator."""

        def __init__(self, rbac_manager: "RBACManager"):
            # Any object exposing check_permission(user_id, permission) -> bool
            self.rbac_manager = rbac_manager

        def check_permission(self,
                             user_id: str,
                             required_permission: str) -> bool:
            """Return whether ``user_id`` holds ``required_permission``; log denials."""
            has_permission = self.rbac_manager.check_permission(
                user_id,
                required_permission
            )
            if not has_permission:
                logger.warning(
                    f"Permission denied: user={user_id}, "
                    f"permission={required_permission}"
                )
            return has_permission

        def require_permission(self, permission: str):
            """Decorator factory: allow calls only if the current user has ``permission``.

            Raises:
                PermissionError: from the decorated call when the check fails.
            """
            import functools

            def decorator(func):
                # wraps() keeps __name__/__doc__ so logs and introspection still
                # see the original function rather than "wrapper".
                @functools.wraps(func)
                def wrapper(*args, **kwargs):
                    # Resolved on every call, not at decoration time, so the
                    # check applies to whoever is calling right now.
                    user_id = self._get_user_id()
                    if not self.check_permission(user_id, permission):
                        raise PermissionError(
                            f"Permission denied: {permission}"
                        )
                    return func(*args, **kwargs)
                return wrapper
            return decorator

        def _get_user_id(self) -> str:
            """Return the current user ID.

            Placeholder implementation: reads the ``USER_ID`` environment
            variable, defaulting to 'anonymous'. NOTE(review): an environment
            variable is process-wide; a multi-user service should take the ID
            from the authenticated request or session instead.
            """
            return os.getenv('USER_ID', 'anonymous')

## 32.2.3 审计日志

### 审计日志记录器

    bash


    python

    class AuditLogger:
        """Audit-trail writer: one JSON object per line in a dedicated log file."""

        def __init__(self, config: Dict):
            """Set up the shared 'audit' logger from ``config``.

            Recognised keys (all optional): ``log_file`` (default 'audit.log'),
            ``log_level`` (a logging level name, default 'INFO') and
            ``retention_days`` (default 90, used by ``cleanup_old_logs``).
            """
            self.log_file = config.get('log_file', 'audit.log')
            self.log_level = config.get('log_level', 'INFO')
            self.retention_days = config.get('retention_days', 90)

            self.logger = logging.getLogger('audit')
            self.logger.setLevel(getattr(logging, self.log_level))

            # getLogger('audit') returns the same object for every instance, so
            # only attach a handler if none already writes to this file;
            # otherwise each entry would be written once per instance.
            target = os.path.abspath(self.log_file)
            already_attached = any(
                isinstance(h, logging.FileHandler) and h.baseFilename == target
                for h in self.logger.handlers
            )
            if not already_attached:
                handler = logging.FileHandler(self.log_file)
                handler.setFormatter(
                    logging.Formatter(
                        '%(asctime)s - %(levelname)s - %(message)s'
                    )
                )
                self.logger.addHandler(handler)

        def log_event(self, event: "AuditEvent"):
            """Write ``event`` as a single JSON line with a naive-UTC timestamp."""
            log_entry = {
                'timestamp': datetime.utcnow().isoformat(),
                'user_id': event.user_id,
                'action': event.action,
                'resource': event.resource,
                'result': event.result,
                'ip_address': event.ip_address,
                'user_agent': event.user_agent,
                'metadata': event.metadata
            }

            self.logger.info(json.dumps(log_entry))

        def log_api_call(self, user_id: str,
                         endpoint: str,
                         method: str,
                         status_code: int,
                         duration_ms: float):
            """Record an API call; the status code is stored as the result."""
            event = AuditEvent(
                user_id=user_id,
                action='API_CALL',
                resource=endpoint,
                result=str(status_code),
                metadata={
                    'method': method,
                    'duration_ms': duration_ms
                }
            )
            self.log_event(event)

        def log_file_access(self, user_id: str,
                            file_path: str,
                            action: str,
                            result: str):
            """Record a file access as action ``FILE_<ACTION>``."""
            event = AuditEvent(
                user_id=user_id,
                action=f'FILE_{action.upper()}',
                resource=file_path,
                result=result
            )
            self.log_event(event)

        def log_permission_check(self, user_id: str,
                                 permission: str,
                                 granted: bool):
            """Record a permission check with result GRANTED or DENIED."""
            event = AuditEvent(
                user_id=user_id,
                action='PERMISSION_CHECK',
                resource=permission,
                result='GRANTED' if granted else 'DENIED'
            )
            self.log_event(event)

        def cleanup_old_logs(self):
            """Remove entries older than ``retention_days`` from the log file.

            Age comes from each entry's own 'timestamp' field. The cutoff uses
            the same UTC clock as ``log_event``. Lines that cannot be parsed
            are kept rather than silently lost.
            """
            cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)

            with open(self.log_file, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            kept_lines = [line for line in lines if self._is_recent(line, cutoff_date)]

            with open(self.log_file, 'w', encoding='utf-8') as f:
                f.writelines(kept_lines)

            logger.info(f"Cleaned up audit logs, removed {len(lines) - len(kept_lines)} entries")

        @staticmethod
        def _is_recent(line: str, cutoff_date) -> bool:
            """Return False only for a parseable entry dated at or before ``cutoff_date``."""
            # The handler's formatter prefixes the JSON message with
            # "<asctime> - <levelname> - ", so json.loads(line) alone never
            # succeeds; the JSON payload is the third ' - '-separated field.
            parts = line.split(' - ', 2)
            payload = parts[2] if len(parts) == 3 else line
            try:
                log_date = datetime.fromisoformat(json.loads(payload)['timestamp'])
                return log_date > cutoff_date
            except (KeyError, TypeError, ValueError):
                # ValueError also covers json.JSONDecodeError; TypeError covers
                # non-dict payloads and aware-vs-naive datetime comparisons.
                return True

### 审计事件类型

    class AuditEvent:
        """A single auditable action: who did what to which resource, and the outcome."""

        def __init__(self,
                     user_id: str,
                     action: str,
                     resource: str = None,
                     result: str = None,
                     ip_address: str = None,
                     user_agent: str = None,
                     metadata: Dict = None):
            self.user_id = user_id
            self.action = action
            self.resource = resource
            self.result = result
            # Fall back to ambient request context when not given explicitly.
            self.ip_address = ip_address or self._get_client_ip()
            self.user_agent = user_agent or self._get_user_agent()
            # A fresh dict per event when none (or an empty one) is supplied.
            self.metadata = metadata or {}

        def __repr__(self) -> str:
            return (f"{type(self).__name__}(user_id={self.user_id!r}, "
                    f"action={self.action!r}, resource={self.resource!r}, "
                    f"result={self.result!r})")

        def _get_client_ip(self) -> str:
            """Return the client IP from the ``REMOTE_ADDR`` env var, else 'unknown'.

            NOTE(review): environment variables only carry per-request data in
            CGI-style deployments; elsewhere this should come from the request.
            """
            return os.getenv('REMOTE_ADDR', 'unknown')

        def _get_user_agent(self) -> str:
            """Return the user agent from ``HTTP_USER_AGENT``, else 'unknown'.

            NOTE(review): same CGI-style assumption as ``_get_client_ip``.
            """
            return os.getenv('HTTP_USER_AGENT', 'unknown')

## 32.2.4 数据保护

### 数据分类

    bash


    python

    class DataClassifier:
        """Assigns a sensitivity level (public/internal/confidential/restricted) to text."""

        # Compiled once at class creation instead of on every call.
        _PII_PATTERNS = (
            # US SSN, e.g. 123-45-6789
            re.compile(r'\b\d{3}-\d{2}-\d{4}\b'),
            # 16-digit card number, optionally grouped in fours by spaces or dashes
            re.compile(r'\b\d{4}(?:[ -]?\d{4}){3}\b'),
            # E-mail address; the TLD class is [A-Za-z] (a '|' inside [...]
            # would be a literal pipe character, not alternation)
            re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'),
        )

        def __init__(self):
            self.classification_rules = {
                'public': {
                    'description': '可以公开访问的数据',
                    'examples': ['public documentation', 'open source code']
                },
                'internal': {
                    'description': '仅限内部访问的数据',
                    'examples': ['internal documentation', 'proprietary code']
                },
                'confidential': {
                    'description': '需要特殊保护的数据',
                    'examples': ['customer data', 'financial information']
                },
                'restricted': {
                    'description': '最高级别的保护',
                    'examples': ['PII', 'trade secrets']
                }
            }

        def classify(self, data: str,
                     context: Dict = None) -> str:
            """Return the classification level for ``data``.

            PII in the text always wins ('restricted'). Otherwise the optional
            ``context`` decides: source == 'customer' -> 'confidential',
            access_level == 'internal' -> 'internal'. The default is 'public'.
            """
            if self._contains_pii(data):
                return 'restricted'

            if context:
                if context.get('source') == 'customer':
                    return 'confidential'
                elif context.get('access_level') == 'internal':
                    return 'internal'

            return 'public'

        def _contains_pii(self, data: str) -> bool:
            """Return True if ``data`` contains an SSN, card number or e-mail address."""
            return any(pattern.search(data) for pattern in self._PII_PATTERNS)

### 数据脱敏

    class DataMasker:
        """Masks sensitive values: e-mail, phone, SSN, card number and IP address."""

        def __init__(self):
            # data type -> masking function taking and returning a string
            self.masking_rules = {
                'email': self._mask_email,
                'phone': self._mask_phone,
                'ssn': self._mask_ssn,
                'credit_card': self._mask_credit_card,
                'ip_address': self._mask_ip_address
            }

        def mask_data(self, data: str,
                      data_type: str = 'auto') -> str:
            """Return ``data`` with its sensitive part masked.

            With ``data_type='auto'`` the type is detected first. Data of an
            unknown or unsupported type is returned unchanged.
            """
            if data_type == 'auto':
                data_type = self._detect_data_type(data)
            masker = self.masking_rules.get(data_type)
            return masker(data) if masker else data

        def _mask_email(self, email: str) -> str:
            """Keep the domain and, for local parts over 3 chars, their first/last char."""
            if '@' not in email:
                return email
            local, domain = email.split('@', 1)
            masked_local = (local[0] + '***' + local[-1]) if len(local) > 3 else '***'
            return f"{masked_local}@{domain}"

        def _mask_phone(self, phone: str) -> str:
            """Keep only the last 4 digits of a number with at least 10 digits."""
            digits = re.sub(r'\D', '', phone)
            if len(digits) >= 10:
                return f"***-***-{digits[-4:]}"
            return '***-***'

        def _mask_ssn(self, ssn: str) -> str:
            """Keep only the last 4 digits of a 9-digit SSN."""
            digits = re.sub(r'\D', '', ssn)
            if len(digits) == 9:
                return f"***-**-{digits[-4:]}"
            return '***-**-****'

        def _mask_credit_card(self, card: str) -> str:
            """Keep only the last 4 digits of a card number with at least 13 digits."""
            digits = re.sub(r'\D', '', card)
            if len(digits) >= 13:
                return f"****-****-****-{digits[-4:]}"
            return '****-****-****-****'

        def _mask_ip_address(self, ip: str) -> str:
            """Mask the host part of an IP address.

            IPv4 keeps the first two octets (``192.168.*.*``). IPv6 keeps the
            first four groups of the expanded form. Unparseable input is fully
            masked.
            """
            parsed = self._parse_ip(ip)
            if parsed is None:
                return '***.***.***.***'
            if parsed.version == 4:
                first, second = str(parsed).split('.')[:2]
                return f"{first}.{second}.*.*"
            return ':'.join(parsed.exploded.split(':')[:4]) + ':****:****:****:****'

        @staticmethod
        def _parse_ip(data: str):
            """Return an ``ipaddress`` address object for ``data``, or None if it isn't one."""
            import ipaddress

            try:
                return ipaddress.ip_address(data.strip())
            except ValueError:
                return None

        def _detect_data_type(self, data: str) -> str:
            """Guess which masking rule applies to ``data`` ('unknown' if none)."""
            if '@' in data and '.' in data.split('@')[1]:
                return 'email'
            if re.match(r'^\d{3}-\d{2}-\d{4}$', data):
                return 'ssn'
            # IPs are checked before the digit-count rules: e.g. '10.100.10.100'
            # has exactly ten digits and would otherwise be taken for a phone.
            if self._parse_ip(data) is not None:
                return 'ip_address'
            digits = re.sub(r'\D', '', data)
            if re.match(r'^\d{16}$', digits):
                return 'credit_card'
            if re.match(r'^\d{10}$', digits):
                return 'phone'
            return 'unknown'

基于 MIT 许可发布 | 永久导航