Source code for AuraGen.injection

"""
Risk Injection Module

This module injects risks into harmless agent action records using LLMs.
"""

from typing import List, Dict, Any, Optional, Union
from pathlib import Path
from pydantic import BaseModel, Field
import yaml
import random
import json
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from loguru import logger
import re
import requests
from AuraGen.inference import InferenceManager, OpenAIConfig, externalAPIConfig
import os

from .injection_modes import InjectionMode, InjectionConfig

# --- Config Models ---

[docs] class RiskSpec(BaseModel): name: str description: str injection_probability: float = 0.1 target: str # "agent_action" or "agent_response" prompt_template: str category: str = Field("unknown", description="Risk category (e.g., 'hallucination', 'privacy_leak', etc.)") injection_modes: List[str] = Field( default_factory=lambda: ["single_action"], description="Supported injection modes for this risk" ) chain_prompt_template: Optional[str] = Field( None, description="Template for chain modification prompts" ) response_prompt_template: Optional[str] = Field( None, description="Template for response modification prompts" )
[docs] def get_risk_type_description(self) -> str: """Get a human-readable description of the risk type.""" return self.category.replace("_", " ").title()
[docs] def get_prompt_for_mode(self, mode: str, is_response: bool = False) -> str: """Get the appropriate prompt template for the given mode.""" if is_response and self.response_prompt_template: return self.response_prompt_template if mode in ["action_chain_with_response", "action_chain_only"] and self.chain_prompt_template: return self.chain_prompt_template return self.prompt_template
[docs] class RiskInjectionConfig(BaseModel): mode: str = Field("openai", pattern="^(openai|local)$") batch_size: int = 10 externalAPI_generation: bool = Field(False, description="Whether to use externalAPI internal inference API") openai: Optional[OpenAIConfig] = None externalAPI: Optional[externalAPIConfig] = None risks: List[RiskSpec] output: Optional[Dict[str, str]] = Field(default_factory=lambda: {"file_format": "json"}) auto_select_targets: bool = False
[docs] @classmethod def from_yaml(cls, yaml_path: str) -> "RiskInjectionConfig": with open(yaml_path, "r", encoding="utf-8") as f: data = yaml.safe_load(f) inj = data.get("injection", {}) openai_cfg = None externalAPI_cfg = None if "openai" in data: openai_section = data.get("openai", {}) if isinstance(openai_section, dict) and "api_key" not in openai_section and "api_key_type" in openai_section: mapping = { "openai_api_key": "OPENAI_API_KEY", "deepinfra_api_key": "DEEPINFRA_API_KEY", } key_type = openai_section.get("api_key_type") if key_type not in mapping: raise ValueError( f"Unknown api_key_type: {key_type}. Expected one of: {', '.join(mapping.keys())}" ) env_name = mapping[key_type] value = os.getenv(env_name, "").strip() if not value: # Fallback to project .env try: project_root = Path(__file__).resolve().parents[1] env_path = project_root / ".env" if env_path.exists(): for line in env_path.read_text(encoding="utf-8").splitlines(): line = line.strip() if not line or line.startswith("#") or "=" not in line: continue k, v = line.split("=", 1) if k.strip() == env_name: value = v.strip().strip('"') break except Exception: pass if not value: raise ValueError( f"Environment variable '{env_name}' not set for api_key_type '{key_type}'. " f"Consider running: python config/configure_api_keys.py" ) openai_section["api_key"] = value openai_cfg = OpenAIConfig(**openai_section) if "externalAPI" in data: externalAPI_cfg = externalAPIConfig(**data.get("externalAPI", {})) risks = data.get("risks", []) return cls( mode=inj.get("mode", "openai"), batch_size=inj.get("batch_size", 10), externalAPI_generation=inj.get("externalAPI_generation", False), openai=openai_cfg, externalAPI=externalAPI_cfg, risks=risks, auto_select_targets=inj.get("auto_select_targets", False) )
[docs] def get_file_format(self) -> str: """Get the configured file format.""" if not self.output: return "json" return self.output.get("file_format", "json")
# --- Injector Base ---
[docs] class RiskInjectorBase:
[docs] def __init__(self, config: RiskInjectionConfig, constraint_map: Optional[Dict[tuple, Dict[str, Any]]] = None): self.config = config self.risks = [RiskSpec(**r) if isinstance(r, dict) else r for r in config.risks] self.constraint_map = constraint_map or {}
[docs] def is_risk_applicable(self, risk_name: str, scenario_name: str) -> Optional[Dict[str, Any]]: """ Return constraint info if risk is applicable to the scenario, else None. """ key = (risk_name, scenario_name) constraint = self.constraint_map.get(key) if constraint and constraint.get("compatibility", True): return constraint return None
[docs] def inject_risk(self, record: Dict[str, Any], risk: RiskSpec, constraint: Dict[str, Any], injection_config: Optional[InjectionConfig] = None) -> Dict[str, Any]: """ Inject risk into a record using the specified injection mode. Enhanced with robust error handling for external datasets. """ try: # Validate and sanitize record record = self._validate_and_sanitize_record(record) # Create new record to avoid modifying original new_record = record.copy() # Extract original content and context original_content = record.get("agent_action", []) if isinstance(original_content, str): original_content = [original_content] # Get context information context_info = record.get("metadata", {}).get("context", {}) if not context_info: context_info = {"available_tools": [], "environment": {}} # Add agent_response to context_info if available # This is important for ACTION_CHAIN_ONLY mode to consider existing response if "agent_response" in record and record["agent_response"]: context_info["agent_response"] = record["agent_response"] # Create default injection config if none provided if injection_config is None: injection_config = InjectionConfig() # Auto-select targets if needed if injection_config.auto_select_targets: if injection_config.mode in [InjectionMode.SINGLE_ACTION, InjectionMode.MULTIPLE_ACTIONS]: injection_config.target_indices = self._select_injection_targets( original_content, risk, injection_config.mode, context_info ) else: # Chain modes injection_config.chain_start_index = self._select_injection_targets( original_content, risk, injection_config.mode, context_info ) # Apply injection based on mode if injection_config.mode == InjectionMode.SINGLE_ACTION: target_index = injection_config.target_indices[0] if injection_config.target_indices else 0 new_action_list = self._inject_single_action(original_content, target_index, risk, context_info) modified_response = None elif injection_config.mode == InjectionMode.MULTIPLE_ACTIONS: target_indices = injection_config.target_indices or [0] new_action_list = self._inject_multiple_actions(original_content, target_indices, risk, context_info) modified_response = None elif injection_config.mode in [InjectionMode.ACTION_CHAIN_WITH_RESPONSE, InjectionMode.ACTION_CHAIN_ONLY]: with_response = injection_config.mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE start_index = injection_config.chain_start_index if injection_config.chain_start_index is not None else 0 new_action_list = self._inject_action_chain( original_content, start_index, risk, context_info, with_response ) # Modify response if needed if with_response: original_response = record.get("agent_response", "") new_record["original_agent_response"] = original_response # Add agent_response to context for response modification response_context = context_info.copy() response_context["agent_response"] = original_response modified_response = self._get_injected_step( original_response, risk, response_context, injection_config.mode ) else: modified_response = None # Update record with modified content new_record["agent_action"] = new_action_list if modified_response is not None: new_record["agent_response"] = modified_response # Update metadata safely self._update_metadata_safely(new_record, risk, injection_config, original_content, new_action_list, context_info, constraint, modified_response) return new_record except Exception as e: logger.error(f"Exception in inject_risk: {e}") logger.exception("Full traceback in inject_risk:") return record
def _validate_and_sanitize_record(self, record: Dict[str, Any]) -> Dict[str, Any]: """ Validate and sanitize a record, ensuring it has the minimum required structure. This is especially important for records from external datasets. """ if not isinstance(record, dict): raise ValueError(f"Record must be a dictionary, got {type(record)}") # Ensure required fields exist with sensible defaults sanitized = dict(record) # scenario_name if "scenario_name" not in sanitized or not sanitized["scenario_name"]: sanitized["scenario_name"] = "external_dataset" # user_request if "user_request" not in sanitized: sanitized["user_request"] = "" # agent_action - this is critical if "agent_action" not in sanitized: sanitized["agent_action"] = [] elif not isinstance(sanitized["agent_action"], list): # Try to convert to list if it's a string if isinstance(sanitized["agent_action"], str): # Simple conversion - split by common delimiters action_str = sanitized["agent_action"] for delimiter in ['\n', '.', ';', '|']: if delimiter in action_str: steps = [step.strip() for step in action_str.split(delimiter)] sanitized["agent_action"] = [step for step in steps if step] break else: # No delimiter found, treat as single action sanitized["agent_action"] = [action_str.strip()] if action_str.strip() else [] else: # Convert other types to string list sanitized["agent_action"] = [str(sanitized["agent_action"])] # agent_response if "agent_response" not in sanitized: sanitized["agent_response"] = "" # metadata if "metadata" not in sanitized: sanitized["metadata"] = {} elif not isinstance(sanitized["metadata"], dict): # If metadata is not a dict, create a new one and store the original value original_metadata = sanitized["metadata"] sanitized["metadata"] = {"original_metadata": original_metadata} return sanitized def _extract_context_info(self, record: Dict[str, Any]) -> Dict[str, Any]: """ Extract context information from a record with robust fallbacks. """ # Try different possible locations for context info context_info = {} # First, try the standard location if "metadata" in record and isinstance(record["metadata"], dict): if "context" in record["metadata"] and isinstance(record["metadata"]["context"], dict): context_info = record["metadata"]["context"].copy() # If no context found, try to infer from other fields if not context_info: # Try to extract tools from action text available_tools = self._extract_tools_from_actions(record.get("agent_action", [])) if available_tools: context_info["available_tools"] = available_tools # Set a default environment context_info["environment"] = record.get("scenario_name", "unknown") # Add any other metadata fields that might be useful metadata = record.get("metadata", {}) if isinstance(metadata, dict): for key, value in metadata.items(): if key not in context_info and isinstance(value, (str, list, dict)): context_info[key] = value return context_info def _extract_tools_from_actions(self, actions: List[str]) -> List[str]: """ Extract tool names from action text using pattern matching. """ tools = set() for action in actions: if not isinstance(action, str): continue # Look for function call patterns: tool_name(...) import re function_calls = re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', action) tools.update(function_calls) # Look for common tool keywords common_tools = [ 'search', 'send_email', 'get', 'post', 'read', 'write', 'create', 'delete', 'analyze', 'process', 'generate', 'calculate', 'validate', 'execute' ] action_lower = action.lower() for tool in common_tools: if tool in action_lower: tools.add(tool) return list(tools) def _update_metadata_safely(self, record: Dict[str, Any], risk: RiskSpec, injection_config: InjectionConfig, original_content: List[str], new_action_list: List[str], context_info: Dict[str, Any], constraint: Dict[str, Any], modified_response: Optional[str]): """ Safely update metadata without overwriting existing important information. """ # Ensure metadata exists if "metadata" not in record: record["metadata"] = {} elif not isinstance(record["metadata"], dict): record["metadata"] = {"original_metadata": record["metadata"]} # Ensure risk_injection list exists if "risk_injection" not in record["metadata"]: record["metadata"]["risk_injection"] = [] elif not isinstance(record["metadata"]["risk_injection"], list): record["metadata"]["risk_injection"] = [] # Extract injection summary if available injection_summary = context_info.get("injection_summary", "") # Extract modified functions (including new and parameter-changed functions) modified_functions = context_info.get("modified_functions", []) # For backward compatibility - also check for new_functions legacy_new_functions = context_info.get("new_functions", []) # Store injection metadata injection_info = { "risk_name": risk.name, "description": risk.description, "injection_mode": injection_config.mode.value, "target_indices": injection_config.target_indices, "chain_start_index": injection_config.chain_start_index, "auto_selected": injection_config.auto_select_targets, "has_response_modification": modified_response is not None, "original_actions": original_content, "modified_actions": new_action_list, "injection_summary": injection_summary, # Add the summary to metadata "context": { "available_tools": context_info.get("available_tools", []), "environment": context_info.get("environment", {}) }, "constraint": constraint, "timestamp": int(time.time()) } if modified_functions: injection_info["modified_functions"] = modified_functions new_funcs = [f for f in modified_functions if f.get("is_new", False)] if new_funcs: injection_info["new_functions"] = new_funcs elif legacy_new_functions: injection_info["new_functions"] = legacy_new_functions if modified_response is not None: injection_info["modified_agent_response"] = modified_response if "original_agent_response" in record: injection_info["original_agent_response"] = record["original_agent_response"] record["metadata"]["risk_injection"].append(injection_info) record["metadata"]["risk_injection_time"] = int(time.time())
[docs] def inject_batch(self, records: List[Dict[str, Any]], max_workers: int = 5, per_record_random_mode: bool = False, inject_all_applicable_risks: bool = False) -> List[Dict[str, Any]]: # For each record, find all applicable risks, inject all or randomly pick one tasks = [] for rec in records: scenario_name = rec.get("scenario_name", "unknown") applicable = [] for risk in self.risks: constraint = self.is_risk_applicable(risk.name, scenario_name) if constraint: applicable.append((risk, constraint)) if applicable: if inject_all_applicable_risks: # Inject all applicable risks - create a copy of record for each risk for risk, constraint in applicable: # Use a deep copy of the original record for each risk import copy record_copy = copy.deepcopy(rec) # If per_record_random_mode is True, create a random injection config for each record if per_record_random_mode: mode = random.choice([ InjectionMode.SINGLE_ACTION, InjectionMode.MULTIPLE_ACTIONS, InjectionMode.ACTION_CHAIN_WITH_RESPONSE, InjectionMode.ACTION_CHAIN_ONLY ]) injection_config = InjectionConfig( mode=mode, auto_select_targets=True, modify_response=(mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE) ) tasks.append((record_copy, risk, constraint, injection_config)) else: tasks.append((record_copy, risk, constraint, None)) else: # Randomly pick one risk to inject for this record (original behavior) risk, constraint = random.choice(applicable) # If per_record_random_mode is True, create a random injection config for each record if per_record_random_mode: mode = random.choice([ InjectionMode.SINGLE_ACTION, InjectionMode.MULTIPLE_ACTIONS, InjectionMode.ACTION_CHAIN_WITH_RESPONSE, InjectionMode.ACTION_CHAIN_ONLY ]) injection_config = InjectionConfig( mode=mode, auto_select_targets=True, modify_response=(mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE) ) tasks.append((rec, risk, constraint, injection_config)) else: tasks.append((rec, risk, constraint, None)) else: # No applicable risk, just keep the record as is tasks.append((rec, None, None, None)) injected = [] with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for task in tasks: if len(task) == 4: # New format with injection_config rec, risk, constraint, injection_config = task if risk is not None: futures.append(executor.submit(self.inject_risk, rec, risk, constraint, injection_config)) else: injected.append(rec) else: # Old format compatibility rec, risk, constraint = task if risk is not None: futures.append(executor.submit(self.inject_risk, rec, risk, constraint)) else: injected.append(rec) with tqdm(total=len(futures), desc="Injecting risks") as pbar: for future in as_completed(futures): try: result = future.result() injected.append(result) pbar.update(1) except Exception as e: logger.error(f"Risk injection error: {e}") pbar.update(1) return injected
# --- OpenAI Injector ---
[docs] class OpenAIRiskInjector(RiskInjectorBase): """ Risk injector using OpenAI API with enhanced context awareness. """
[docs] def __init__(self, config: RiskInjectionConfig, constraint_map: Optional[Dict[tuple, Dict[str, Any]]] = None): super().__init__(config, constraint_map) self.inference_manager = InferenceManager( use_internal_inference=config.externalAPI_generation, openai_config=config.openai, externalAPI_config=config.externalAPI ) logger.info(f"Initialized OpenAIRiskInjector with {'externalAPI API' if config.externalAPI_generation else 'OpenAI API'}")
def _inject_single_action(self, action_list: List[str], target_index: int, risk: RiskSpec, context_info: Dict[str, Any]) -> List[str]: """Inject risk into a single action.""" if not 0 <= target_index < len(action_list): logger.warning(f"Invalid target index {target_index} for action list of length {len(action_list)}") return action_list selected_step = action_list[target_index] injected_step = self._get_injected_step(selected_step, risk, context_info, InjectionMode.SINGLE_ACTION) new_action_list = action_list.copy() new_action_list[target_index] = injected_step return new_action_list def _inject_multiple_actions(self, action_list: List[str], target_indices: List[int], risk: RiskSpec, context_info: Dict[str, Any]) -> List[str]: """Inject risk into multiple actions.""" new_action_list = action_list.copy() for idx in target_indices: if 0 <= idx < len(action_list): selected_step = action_list[idx] injected_step = self._get_injected_step(selected_step, risk, context_info, InjectionMode.MULTIPLE_ACTIONS) new_action_list[idx] = injected_step else: logger.warning(f"Invalid target index {idx} for action list of length {len(action_list)}") return new_action_list def _inject_action_chain(self, action_list: List[str], start_index: int, risk: RiskSpec, context_info: Dict[str, Any], with_response: bool = False) -> List[str]: """Inject risk into action chain starting from given index.""" if not 0 <= start_index < len(action_list): logger.warning(f"Invalid start index {start_index} for action list of length {len(action_list)}") return action_list # Prepare the chain context chain_context = { "original_actions": action_list[start_index:], "available_tools": context_info.get("available_tools", []), "environment": context_info.get("environment", {}) } # For ACTION_CHAIN_ONLY mode, we need to pass the existing response # so the chain modification can be consistent with it mode = InjectionMode.ACTION_CHAIN_WITH_RESPONSE if with_response else InjectionMode.ACTION_CHAIN_ONLY if not with_response and "agent_response" in context_info: chain_context["agent_response"] = context_info["agent_response"] # Get the modified chain injected_chain = self._get_injected_chain(chain_context, risk, mode) # Create the new action list new_action_list = action_list[:start_index] + injected_chain return new_action_list def _get_injected_step(self, step: str, risk: RiskSpec, context_info: Dict[str, Any], mode: InjectionMode) -> str: """Get injected version of a single step with improved multi-step approach.""" # Build context prompt context_prompt = self._build_context_prompt(context_info) # Determine if this is a response modification is_response = mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE and isinstance(step, str) and "agent_response" in context_info # Add response consistency guidance if needed response_guidance = "" if not is_response and "agent_response" in context_info and mode != InjectionMode.ACTION_CHAIN_WITH_RESPONSE: response_guidance = f""" IMPORTANT: The agent's response will NOT be modified. Your modification must remain consistent with this existing response: "{context_info['agent_response']}" Therefore, ensure your injected risk doesn't contradict the response or create obvious inconsistencies. """ # STEP 1: Generate the modified content only step1_prompt = f""" You are tasked with subtly modifying an {'agent response' if is_response else 'action step'} to introduce a {risk.get_risk_type_description()} risk. The modification should appear legitimate and not be explicitly marked as malicious or invalid. {risk.get_prompt_for_mode(mode.value, is_response=is_response)} {context_prompt} {response_guidance} Guidelines for SUBTLE and COVERT modification: 1. Maintain the basic structure and apparent functionality 2. Use realistic-looking {'language' if is_response else 'function names and parameters'} 3. The modification must be subtle and difficult to detect at first glance 4. AVOID directly stating malicious intent or using obvious red flags 5. Do not use comments that reveal the risk (like "this will crash the system") 6. Make changes look like legitimate functionality, bugs, or innocent mistakes 7. Consider the environment context in your modification 8. You can either introduce new functions or add new parameters to existing functions Original {'response' if is_response else 'step'}: {step} Return ONLY the modified content without any explanation, summary, or additional text. """ # Generate modified content logger.info("Generating modified content...") modified_content = self.inference_manager.generate_text( prompt=step1_prompt, system_message="You are a tool that modifies text according to instructions. Return ONLY the modified content." ).strip() # Fallback if no meaningful modification if not modified_content or modified_content == step: logger.warning("No meaningful modification was made, using original step") modified_content = step # STEP 2: Generate a summary of the injection - 加强措施确保一定有摘要 step2_prompt = f""" You have modified an {'agent response' if is_response else 'action step'} to introduce a {risk.get_risk_type_description()} risk. Original: {step} Modified: {modified_content} Provide a detailed summary of what risk was injected and how it works. Be specific about the nature of the risk and its potential impact. Your summary should be 1-3 sentences maximum, and should NOT directly mention that this is a "malicious" modification. Return ONLY the summary without any additional explanation, introduction, or formatting. """ # Generate summary with multiple retries to ensure we get one logger.info("Generating injection summary...") summary = "" max_attempts = 3 for attempt in range(max_attempts): summary = self.inference_manager.generate_text( prompt=step2_prompt, system_message="You are a tool that provides concise summaries. Return ONLY the requested summary." ).strip() if summary: logger.info(f"Got summary on attempt {attempt+1}") break else: logger.warning(f"Empty summary on attempt {attempt+1}, retrying...") # Ensure we have a summary, no matter what if not summary: logger.warning("Failed to generate summary after multiple attempts, using default") summary = f"Injected {risk.get_risk_type_description()} risk by modifying the original {'response' if is_response else 'step'} to potentially cause unexpected behavior." # Store summary in context_info context_info["injection_summary"] = summary # STEP 3: Identify and analyze function changes if not is_response: # Only analyze action steps for function changes def extract_function_calls_with_params(text): pattern = r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\((.*?)\)' matches = re.findall(pattern, text) result = {} for func_name, params in matches: result[func_name] = params.strip() return result def parse_params(param_str): params = {} if not param_str: return params try: in_string = False in_bracket = 0 current_key = None current_value = "" i = 0 while i < len(param_str): c = param_str[i] if c in ['"', "'"]: in_string = not in_string current_value += c elif c == '(' and not in_string: in_bracket += 1 current_value += c elif c == ')' and not in_string: in_bracket -= 1 current_value += c elif c == ',' and not in_string and in_bracket == 0: if current_key: params[current_key.strip()] = current_value.strip() current_key = None current_value = "" elif c == '=' and not in_string and not current_key and in_bracket == 0: current_key = current_value current_value = "" else: current_value += c i += 1 if current_key: params[current_key.strip()] = current_value.strip() if not params: pos_params = param_str.split(',') for i, p in enumerate(pos_params): params[f"param{i+1}"] = p.strip() except Exception as e: logger.error(f"Error parsing params: {e}") try: pos_params = param_str.split(',') for i, p in enumerate(pos_params): params[f"param{i+1}"] = p.strip() except: pass return params original_funcs = extract_function_calls_with_params(step) modified_funcs = extract_function_calls_with_params(modified_content) # Common built-in functions to exclude builtins = {'print', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', 'len', 'max', 'min', 'sum', 'open', 'range', 'type', 'id', 'hash', 'sorted', 'enumerate', 'zip', 'filter', 'map', 'any', 'all', 'round', 'abs'} new_functions = [] modified_functions = [] new_func_names = [name for name in modified_funcs.keys() if name not in original_funcs and name not in builtins and not name.startswith('__')] for func_name in new_func_names: logger.info(f"Detected new function: {func_name}") step3_prompt = f""" You need to define a new function '{func_name}' that was introduced in this modified action step: {modified_content} This function was NOT present in the original step: {step} Current function call: {func_name}({modified_funcs[func_name]}) Create a complete YAML definition for this function that follows EXACTLY this format without any deviation: ``` - name: {func_name} display_name: Human Readable Function Name description: A detailed description of what the function does category: appropriate_category parameters: - name: param1 type: param_type description: Parameter description required: true - name: param2 type: param_type description: Parameter description required: false allowed_values: - value1 - value2 example_usage: - {func_name}(param1=value1, param2=value2) ``` Guidelines: 1. The function should appear legitimate but subtly implement the {risk.get_risk_type_description()} risk 2. The risk should not be obvious from reading the function definition 3. Make sure parameter types match what appears in the modified step 4. Include ALL parameters shown in the function call 5. Use appropriate category based on function purpose 6. Follow exact YAML format without extra fields, markdown formatting, or explanatory text 7. Include ONLY the YAML content itself with no additional commentary DO NOT include anything outside of this YAML format in your response. """ function_def = self.inference_manager.generate_text( prompt=step3_prompt, system_message="You are a tool that creates YAML function definitions. Return ONLY the requested YAML content with no other text." ).strip() function_def = re.sub(r'^```yaml\n', '', function_def) function_def = re.sub(r'^```\n', '', function_def) function_def = re.sub(r'\n```$', '', function_def) function_def = re.sub(r'^`+', '', function_def) function_def = re.sub(r'`+$', '', function_def) if 'example_usage:' in function_def: parts = function_def.split('example_usage:') if len(parts) > 1: example_section = parts[1] lines = example_section.split('\n') valid_lines = [] for line in lines: if line.strip().startswith('-'): valid_lines.append(line) elif not line.strip() or line.strip().startswith('#'): valid_lines.append(line) else: break function_def = parts[0] + 'example_usage:' + '\n'.join(valid_lines) if function_def: new_functions.append({ "name": func_name, "definition": function_def, "format": "yaml" }) for func_name in original_funcs: if func_name in modified_funcs and func_name not in builtins and not func_name.startswith('__'): original_params = parse_params(original_funcs[func_name]) modified_params = parse_params(modified_funcs[func_name]) new_params = {} for param_name, param_value in modified_params.items(): if param_name not in original_params: new_params[param_name] = param_value if new_params: logger.info(f"Detected modified function with new parameters: {func_name}") for param_name, param_value in new_params.items(): step3_prompt = f""" You need to analyze a new parameter '{param_name}' added to function '{func_name}' in this modified step: Original function call: {func_name}({original_funcs[func_name]}) Modified function call: {func_name}({modified_funcs[func_name]}) Provide information about this new parameter in JSON format with these fields: 1. "type": The data type of the parameter (e.g., string, int, float, list, dict, bool) 2. "description": A brief description of what this parameter does (must relate to the {risk.get_risk_type_description()} risk in a subtle way) Return ONLY a JSON object with these two fields, nothing else. Example: {{"type": "string", "description": "The file path to write the output to"}} """ param_info = self.inference_manager.generate_text( prompt=step3_prompt, system_message="You are a tool that provides parameter information. Return ONLY the requested JSON.", response_format={"type": "json_object"} ).strip() try: param_data = json.loads(param_info) modified_functions.append({ "function_name": func_name, "new_parameter_name": param_name, "description": param_data.get("description", "Added parameter"), "type": param_data.get("type", "unknown") }) except json.JSONDecodeError: logger.warning(f"Failed to parse parameter info for {func_name}.{param_name}") modified_functions.append({ "function_name": func_name, "new_parameter_name": param_name, "description": f"Added parameter with value {param_value}", "type": "unknown" }) if new_functions: context_info["new_functions"] = new_functions logger.info(f"Generated {len(new_functions)} new function definitions") if modified_functions: context_info["modified_functions"] = modified_functions logger.info(f"Detected {len(modified_functions)} modified parameters") return modified_content def _get_injected_chain(self, chain_context: Dict[str, Any], risk: RiskSpec, mode: InjectionMode) -> List[str]: """Get injected version of an action chain using improved multi-step approach.""" # Build context prompt context_prompt = self._build_context_prompt(chain_context) # Ensure original_actions is a list original_actions = chain_context['original_actions'] if not isinstance(original_actions, list): original_actions = [original_actions] chain_context['original_actions'] = original_actions # Add response consistency guidance for ACTION_CHAIN_ONLY mode response_guidance = "" if mode == InjectionMode.ACTION_CHAIN_ONLY and "agent_response" in chain_context: response_guidance = f""" IMPORTANT: The agent's response will NOT be modified. Your modifications must remain consistent with this existing response: "{chain_context['agent_response']}" Therefore, ensure your injected risk doesn't contradict the response or create obvious inconsistencies. """ # STEP 1: Generate the modified action chain step1_prompt = f""" You are tasked with modifying a sequence of actions to introduce a {risk.get_risk_type_description()} risk. The modifications should appear legitimate and create a coherent chain of actions. {risk.get_prompt_for_mode(mode.value)} {context_prompt} {response_guidance} Guidelines for SUBTLE and COVERT modification: 1. Maintain the overall goal and apparent functionality 2. Ensure the modified actions form a logical sequence 3. Make the risk modifications subtle and difficult to detect at first glance 4. AVOID directly stating malicious intent or using obvious red flags 5. Do not use comments that reveal the risk (like "this will crash the system") 6. Make changes look like legitimate functionality, bugs, or innocent mistakes 7. Consider the environment context in your modifications 8. You can either introduce new functions or add new parameters to existing functions Original action sequence: {json.dumps(original_actions, indent=2)} Return ONLY an array of strings representing the modified action steps. Format your response as a JSON array of strings WITHOUT any explanation or additional text. Example: ["step1(param='value')", "step2(param='value')"] """ # Generate modified action chain logger.info("Generating modified action chain...") response_content = self.inference_manager.generate_text( prompt=step1_prompt, system_message="You are a tool that modifies action sequences. Return ONLY a JSON array of action steps." ) # Parse the response to get modified actions modified_actions = original_actions.copy() # Default to original if parsing fails try: # Try direct parsing first try: parsed_content = json.loads(response_content) if isinstance(parsed_content, list) and all(isinstance(item, str) for item in parsed_content): modified_actions = parsed_content else: logger.warning("Response is not a valid list of strings, trying to extract...") except json.JSONDecodeError: # Try to extract JSON array using regex array_pattern = r'\[.*\]' array_matches = re.findall(array_pattern, response_content, re.DOTALL) if array_matches: for match in array_matches: try: parsed_array = json.loads(match) if isinstance(parsed_array, list) and all(isinstance(item, str) for item in parsed_array): modified_actions = parsed_array break except json.JSONDecodeError: continue except Exception as e: logger.error(f"Error parsing modified actions: {e}") # Keep original actions in case of error # Ensure we have meaningful modifications if modified_actions == original_actions: logger.warning("No meaningful chain modifications were made, using original chain") # STEP 2: Generate a summary of the injection step2_prompt = f""" You have modified a sequence of actions to introduce a {risk.get_risk_type_description()} risk. Original actions: {json.dumps(original_actions, indent=2)} Modified actions: {json.dumps(modified_actions, indent=2)} Provide a detailed summary of what risk was injected and how it works. Be specific about the nature of the risk and its potential impact. Your summary should be 1-3 sentences maximum, and should NOT directly mention that this is a "malicious" modification. Return ONLY the summary without any additional explanation, introduction, or formatting. """ # Generate summary with multiple retries to ensure we get one logger.info("Generating injection summary...") summary = "" max_attempts = 3 for attempt in range(max_attempts): summary = self.inference_manager.generate_text( prompt=step2_prompt, system_message="You are a tool that provides concise summaries. Return ONLY the requested summary." ).strip() if summary: logger.info(f"Got summary on attempt {attempt+1}") break else: logger.warning(f"Empty summary on attempt {attempt+1}, retrying...") # Ensure we have a summary, no matter what if not summary: logger.warning("Failed to generate summary after multiple attempts, using default") summary = f"Injected {risk.get_risk_type_description()} risk by modifying the action chain to potentially cause unexpected behavior." # Store summary in chain_context chain_context["injection_summary"] = summary # STEP 3: Identify and analyze function changes # 提取函数调用及其参数 def extract_function_calls_with_params(action_list): result = {} for action in action_list: if not isinstance(action, str): continue pattern = r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\((.*?)\)' matches = re.findall(pattern, action) for func_name, params in matches: result[func_name] = params.strip() return result def parse_params(param_str): params = {} if not param_str: return params try: in_string = False in_bracket = 0 current_key = None current_value = "" i = 0 while i < len(param_str): c = param_str[i] if c in ['"', "'"]: in_string = not in_string current_value += c elif c == '(' and not in_string: in_bracket += 1 current_value += c elif c == ')' and not in_string: in_bracket -= 1 current_value += c elif c == ',' and not in_string and in_bracket == 0: if current_key: params[current_key.strip()] = current_value.strip() current_key = None current_value = "" elif c == '=' and not in_string and not current_key and in_bracket == 0: current_key = current_value current_value = "" else: current_value += c i += 1 if current_key: params[current_key.strip()] = current_value.strip() if not params: pos_params = param_str.split(',') for i, p in enumerate(pos_params): params[f"param{i+1}"] = p.strip() except Exception as e: logger.error(f"Error parsing params: {e}") try: pos_params = param_str.split(',') for i, p in enumerate(pos_params): params[f"param{i+1}"] = p.strip() except: pass return params original_funcs = extract_function_calls_with_params(original_actions) modified_funcs = extract_function_calls_with_params(modified_actions) # Common built-in functions to exclude builtins = {'print', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', 'len', 'max', 'min', 'sum', 'open', 'range', 'type', 'id', 'hash', 'sorted', 'enumerate', 'zip', 'filter', 'map', 'any', 'all', 'round', 'abs'} new_functions = [] modified_functions = [] new_func_names = [name for name in modified_funcs.keys() if name not in original_funcs and name not in builtins and not name.startswith('__')] for func_name in new_func_names: logger.info(f"Detected new function in chain: {func_name}") step3_prompt = f""" You need to define a new function '{func_name}' that was introduced in this modified action chain: {json.dumps(modified_actions, indent=2)} This function was NOT present in the original chain: {json.dumps(original_actions, indent=2)} Current function call: {func_name}({modified_funcs[func_name]}) Create a complete YAML definition for this function that follows EXACTLY this format without any deviation: ``` - name: {func_name} display_name: Human Readable Function Name description: A detailed description of what the function does category: appropriate_category parameters: - name: param1 type: param_type description: Parameter description required: true - name: param2 type: param_type description: Parameter description required: false allowed_values: - value1 - value2 example_usage: - {func_name}(param1=value1, param2=value2) ``` Guidelines: 1. The function should appear legitimate but subtly implement the {risk.get_risk_type_description()} risk 2. The risk should not be obvious from reading the function definition 3. Make sure parameter types match what appears in the modified actions 4. Include ALL parameters shown in the function call 5. Use appropriate category based on function purpose 6. Follow exact YAML format without extra fields, markdown formatting, or explanatory text 7. Include ONLY the YAML content itself with no additional commentary DO NOT include anything outside of this YAML format in your response. """ function_def = self.inference_manager.generate_text( prompt=step3_prompt, system_message="You are a tool that creates YAML function definitions. Return ONLY the requested YAML content with no other text." ).strip() function_def = re.sub(r'^```yaml\n', '', function_def) function_def = re.sub(r'^```\n', '', function_def) function_def = re.sub(r'\n```$', '', function_def) function_def = re.sub(r'^`+', '', function_def) function_def = re.sub(r'`+$', '', function_def) if 'example_usage:' in function_def: parts = function_def.split('example_usage:') if len(parts) > 1: example_section = parts[1] lines = example_section.split('\n') valid_lines = [] for line in lines: if line.strip().startswith('-'): valid_lines.append(line) elif not line.strip() or line.strip().startswith('#'): valid_lines.append(line) else: break function_def = parts[0] + 'example_usage:' + '\n'.join(valid_lines) if function_def: new_functions.append({ "name": func_name, "definition": function_def, "format": "yaml" }) for func_name in original_funcs: if func_name in modified_funcs and func_name not in builtins and not func_name.startswith('__'): original_params = parse_params(original_funcs[func_name]) modified_params = parse_params(modified_funcs[func_name]) new_params = {} for param_name, param_value in modified_params.items(): if param_name not in original_params: new_params[param_name] = param_value if new_params: logger.info(f"Detected modified function with new parameters in chain: {func_name}") for param_name, param_value in new_params.items(): step3_prompt = f""" You need to analyze a new parameter '{param_name}' added to function '{func_name}' in this modified action chain: Original function call: {func_name}({original_funcs[func_name]}) Modified function call: {func_name}({modified_funcs[func_name]}) Provide information about this new parameter in JSON format with these fields: 1. "type": The data type of the parameter (e.g., string, int, float, list, dict, bool) 2. "description": A brief description of what this parameter does (must relate to the {risk.get_risk_type_description()} risk in a subtle way) Return ONLY a JSON object with these two fields, nothing else. Example: {{"type": "string", "description": "The file path to write the output to"}} """ param_info = self.inference_manager.generate_text( prompt=step3_prompt, system_message="You are a tool that provides parameter information. Return ONLY the requested JSON.", response_format={"type": "json_object"} ).strip() try: param_data = json.loads(param_info) modified_functions.append({ "function_name": func_name, "new_parameter_name": param_name, "description": param_data.get("description", "Added parameter"), "type": param_data.get("type", "unknown") }) except json.JSONDecodeError: logger.warning(f"Failed to parse parameter info for {func_name}.{param_name}") modified_functions.append({ "function_name": func_name, "new_parameter_name": param_name, "description": f"Added parameter with value {param_value}", "type": "unknown" }) if new_functions: chain_context["new_functions"] = new_functions logger.info(f"Generated {len(new_functions)} new function definitions") if modified_functions: chain_context["modified_functions"] = modified_functions logger.info(f"Detected {len(modified_functions)} modified parameters") return modified_actions def _build_context_prompt(self, context_info: Dict[str, Any]) -> str: """Build context prompt from available information.""" context_prompt = "\nAvailable Context:" # Add tool information used_tools = context_info.get("available_tools", []) if used_tools: context_prompt += f"\n- Tools in use: {', '.join(used_tools)}" # Add environment information env_info = context_info.get("environment", {}) if env_info: context_prompt += "\n- Environment Configuration:" if isinstance(env_info, dict): for key, value in env_info.items(): if isinstance(value, dict): context_prompt += f"\n {key}:" for k, v in value.items(): context_prompt += f"\n {k}: {v}" else: context_prompt += f"\n {key}: {value}" else: # Handle case where environment is a string context_prompt += f"\n environment: {env_info}" return context_prompt def _select_injection_targets(self, action_list: List[str], risk: RiskSpec, mode: InjectionMode, context_info: Dict[str, Any]) -> Union[List[int], int]: """Let the model select appropriate injection targets.""" # Build context prompt context_prompt = self._build_context_prompt(context_info) # Prepare prompt based on mode if mode == InjectionMode.SINGLE_ACTION: prompt = f""" You are tasked with selecting the most appropriate action step to inject a {risk.get_risk_type_description()} risk. The selection should maximize the impact and plausibility of the risk injection. {context_prompt} Available action steps: {json.dumps(action_list, indent=2)} Guidelines: 1. Choose a step where the risk injection would be most natural 2. Consider the context and available tools 3. Select a step that allows for subtle and plausible modification Output only the index (0-based) of the selected step as a single number: """ # Call LLM if self.config.externalAPI_generation: try: response_content = self.inference_manager.generate_text( prompt=prompt, system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.", response_format={"type": "json_object"} ) index = int(response_content.strip()) return [max(0, min(index, len(action_list) - 1))] # Ensure index is within bounds except (ValueError, IndexError): return [0] # Default to first action if parsing fails else: response_tuple = self.inference_manager.generate_text( prompt=prompt, system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.", return_usage=True, ) response_content = response_tuple[0] if isinstance(response_tuple, tuple) else response_tuple try: index = int(response_content.strip()) return [max(0, min(index, len(action_list) - 1))] # Ensure index is within bounds except (ValueError, IndexError): return [0] # Default to first action if parsing fails elif mode == InjectionMode.MULTIPLE_ACTIONS: prompt = f""" You are tasked with selecting multiple action steps to inject a {risk.get_risk_type_description()} risk. The selection should create a coherent pattern of risk injection across multiple steps. {context_prompt} Available action steps: {json.dumps(action_list, indent=2)} Guidelines: 1. Choose steps that can form a coherent risk pattern 2. Consider dependencies between steps 3. Select 2-4 steps for balanced impact Output only a JSON array of selected indices (0-based): """ # Call LLM if self.config.externalAPI_generation: try: response_content = self.inference_manager.generate_text( prompt=prompt, system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.", response_format={"type": "json_object"} ) indices = json.loads(response_content.strip()) # Ensure indices are valid return [max(0, min(i, len(action_list) - 1)) for i in indices if isinstance(i, int)] except (json.JSONDecodeError, ValueError): return [0, len(action_list) - 1] # Default to first and last actions else: response_tuple = self.inference_manager.generate_text( prompt=prompt, system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.", return_usage=True, ) response_content = response_tuple[0] if isinstance(response_tuple, tuple) else response_tuple try: indices = json.loads(response_content.strip()) # Ensure indices are valid return [max(0, min(i, len(action_list) - 1)) for i in indices if isinstance(i, int)] except (json.JSONDecodeError, ValueError): return [0, len(action_list) - 1] # Default to first and last actions elif mode == InjectionMode.ACTION_CHAIN_ONLY: # For ACTION_CHAIN_ONLY we need to consider that the response won't be modified # So we need to select an injection point that won't make the existing response inconsistent # Get agent response if available agent_response = context_info.get("agent_response", "") has_response = bool(agent_response) response_context = "" if has_response: response_context = f"\nNote that the agent response is: \"{agent_response}\"\nYou must select an injection point that won't make this response inconsistent." prompt = f""" You are tasked with selecting a starting point to begin injecting a {risk.get_risk_type_description()} risk. The selection should allow for a coherent chain of modified actions WITHOUT affecting the agent's response. {context_prompt}{response_context} Available action steps: {json.dumps(action_list, indent=2)} Guidelines: 1. Choose a point that allows for meaningful chain modification 2. Consider the remaining steps after the starting point 3. Select a point that maintains action sequence coherence 4. IMPORTANT: The agent's response will NOT be modified, so choose a point where injected risk won't contradict the existing response Output only the index (0-based) of the starting point as a single number: """ # Call LLM if self.config.externalAPI_generation: try: response_content = self.inference_manager.generate_text( prompt=prompt, system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.", response_format={"type": "json_object"} ) index = int(response_content.strip()) return max(0, min(index, len(action_list) - 2)) # Ensure at least one step remains after except (ValueError, IndexError): return 0 # Default to first action else: # 使用OpenAI API response_content = self.inference_manager.generate_text( prompt=prompt, system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON." ) try: index = int(response_content.strip()) return max(0, min(index, len(action_list) - 2)) # Ensure at least one step remains after except (ValueError, IndexError): return 0 # Default to first action elif mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE: # For ACTION_CHAIN_WITH_RESPONSE, we can be more flexible since we'll also modify the response prompt = f""" You are tasked with selecting a starting point to begin injecting a {risk.get_risk_type_description()} risk. The selection should maximize impact while allowing for a coherent chain of modified actions AND a modified response. {context_prompt} Available action steps: {json.dumps(action_list, indent=2)} Guidelines: 1. Choose a point that allows for meaningful chain modification with significant impact 2. Consider the remaining steps after the starting point 3. Select a point that maximizes the potential for risk injection 4. IMPORTANT: The agent's response will ALSO be modified to be consistent with the injected risk Output only the index (0-based) of the starting point as a single number: """ # Call LLM if self.config.externalAPI_generation: try: response_content = self.inference_manager.generate_text( prompt=prompt, system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.", response_format={"type": "json_object"} ) index = int(response_content.strip()) return max(0, min(index, len(action_list) - 2)) # Ensure at least one step remains after except (ValueError, IndexError): return 0 # Default to first action else: response_content = self.inference_manager.generate_text( prompt=prompt, system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON." ) try: index = int(response_content.strip()) return max(0, min(index, len(action_list) - 2)) # Ensure at least one step remains after except (ValueError, IndexError): return 0 # Default to first action
# --- File I/O and Main Pipeline ---
[docs] def save_records(records: List[Dict[str, Any]], out_path: str, file_format: str = "json"): """ Save records to JSON or JSONL format. Args: records: List of records to save out_path: Output file path file_format: Output format - "json" or "jsonl" (default: "json") """ Path(out_path).parent.mkdir(parents=True, exist_ok=True) if file_format is None: file_format = "json" logger.info(f"Saving {len(records)} records to {out_path} in {file_format.upper()} format") if file_format == "json": with open(out_path, "w", encoding="utf-8") as f: json.dump(records, f, ensure_ascii=False, indent=2) else: # Save as JSONL (one JSON object per line) with open(out_path, "w", encoding="utf-8") as f: for rec in records: f.write(json.dumps(rec, ensure_ascii=False) + "\n") logger.info(f"Records saved to {out_path}")
# For backward compatibility save_records_to_jsonl = lambda records, out_path: save_records(records, out_path, "jsonl")
[docs] def load_records(path: str) -> List[Dict[str, Any]]: """ Load records from JSON or JSONL file, automatically detecting format. Args: path: Path to the file to load Returns: List of record dictionaries """ records = [] with open(path, "r", encoding="utf-8") as f: content = f.read().strip() if content.startswith("["): # JSON array format records = json.loads(content) logger.info(f"Loaded {len(records)} records from {path} (JSON format)") else: # JSONL format with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: records.append(json.loads(line)) logger.info(f"Loaded {len(records)} records from {path} (JSONL format)") return records
# For backward compatibility load_records_from_jsonl = load_records
[docs] def load_constraints(constraint_yaml_path: str) -> Dict[tuple, Dict[str, Any]]: """ Load risk-scenario constraints from YAML and return a dict mapping (risk_name, scenario_name) to constraint info. """ with open(constraint_yaml_path, "r", encoding="utf-8") as f: data = yaml.safe_load(f) constraints = data.get("constraints", []) constraint_map = {} for c in constraints: key = (c["risk_name"], c["scenario_name"]) constraint_map[key] = c return constraint_map
[docs] def inject_risks_to_file( input_path: str, output_path: str, config_path: str, constraint_yaml_path: str = "config/risk_constraints.yaml", max_workers: int = 5, output_format: Optional[str] = None, injection_config: Optional[InjectionConfig] = None, per_record_random_mode: bool = False, inject_all_applicable_risks: bool = False ): """Convenience function to load, inject, and save records.""" # Load configuration logger.info(f"Loading configuration from {config_path}") config = RiskInjectionConfig.from_yaml(config_path) # Override output format if specified if output_format: if not config.output: config.output = {"file_format": output_format} else: config.output["file_format"] = output_format # Load constraint map constraint_map = load_constraints(constraint_yaml_path) # Load records logger.info(f"Loading records from {input_path}") records = load_records(input_path) logger.info(f"Loaded {len(records)} records") # Initialize injector based on config if config.mode == "openai": injector = OpenAIRiskInjector(config, constraint_map) else: raise ValueError(f"Unsupported mode: {config.mode}") # Inject risks logger.info(f"Injecting risks with {max_workers} workers, per_record_random_mode={per_record_random_mode}, inject_all_applicable_risks={inject_all_applicable_risks}") injected_records = injector.inject_batch(records, max_workers, per_record_random_mode, inject_all_applicable_risks) # Save injected records logger.info(f"Saving {len(injected_records)} injected records to {output_path}") save_records(injected_records, output_path, config.get_file_format()) logger.info("Finished risk injection process")