"""
Risk Injection Module

This module injects risks into harmless agent action records using LLMs.
"""
from typing import List, Dict, Any, Optional, Union
from pathlib import Path
from pydantic import BaseModel, Field
import yaml
import random
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from loguru import logger
import re
import copy
from AuraGen.inference import InferenceManager, OpenAIConfig, externalAPIConfig
import os
from .injection_modes import InjectionMode, InjectionConfig
# --- Config Models ---
class RiskSpec(BaseModel):
name: str
description: str
injection_probability: float = 0.1
target: str # "agent_action" or "agent_response"
prompt_template: str
category: str = Field("unknown", description="Risk category (e.g., 'hallucination', 'privacy_leak', etc.)")
injection_modes: List[str] = Field(
default_factory=lambda: ["single_action"],
description="Supported injection modes for this risk"
)
chain_prompt_template: Optional[str] = Field(
None,
description="Template for chain modification prompts"
)
response_prompt_template: Optional[str] = Field(
None,
description="Template for response modification prompts"
)
def get_risk_type_description(self) -> str:
"""Get a human-readable description of the risk type."""
return self.category.replace("_", " ").title()
def get_prompt_for_mode(self, mode: str, is_response: bool = False) -> str:
"""Get the appropriate prompt template for the given mode."""
if is_response and self.response_prompt_template:
return self.response_prompt_template
if mode in ["action_chain_with_response", "action_chain_only"] and self.chain_prompt_template:
return self.chain_prompt_template
return self.prompt_template
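# Illustrative usage sketch: constructing a RiskSpec directly and resolving a
# prompt for a chain mode. All values below are hypothetical, not shipped
# defaults.
#
#   risk = RiskSpec(
#       name="privacy_leak_email",
#       description="Quietly forwards user contact details to a side channel",
#       target="agent_action",
#       category="privacy_leak",
#       prompt_template="Modify the step so that ...",
#   )
#   risk.get_prompt_for_mode("action_chain_only")  # no chain template set, so
#                                                  # this falls back to prompt_template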
class RiskInjectionConfig(BaseModel):
mode: str = Field("openai", pattern="^(openai|local)$")
batch_size: int = 10
    externalAPI_generation: bool = Field(False, description="Whether to use the externalAPI internal inference service")
openai: Optional[OpenAIConfig] = None
externalAPI: Optional[externalAPIConfig] = None
risks: List[RiskSpec]
output: Optional[Dict[str, str]] = Field(default_factory=lambda: {"file_format": "json"})
auto_select_targets: bool = False
@classmethod
def from_yaml(cls, yaml_path: str) -> "RiskInjectionConfig":
with open(yaml_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
inj = data.get("injection", {})
openai_cfg = None
externalAPI_cfg = None
if "openai" in data:
openai_section = data.get("openai", {})
if isinstance(openai_section, dict) and "api_key" not in openai_section and "api_key_type" in openai_section:
mapping = {
"openai_api_key": "OPENAI_API_KEY",
"deepinfra_api_key": "DEEPINFRA_API_KEY",
}
key_type = openai_section.get("api_key_type")
if key_type not in mapping:
raise ValueError(
f"Unknown api_key_type: {key_type}. Expected one of: {', '.join(mapping.keys())}"
)
env_name = mapping[key_type]
value = os.getenv(env_name, "").strip()
if not value:
# Fallback to project .env
try:
project_root = Path(__file__).resolve().parents[1]
env_path = project_root / ".env"
if env_path.exists():
for line in env_path.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
k, v = line.split("=", 1)
if k.strip() == env_name:
value = v.strip().strip('"')
break
except Exception:
pass
if not value:
raise ValueError(
f"Environment variable '{env_name}' not set for api_key_type '{key_type}'. "
f"Consider running: python config/configure_api_keys.py"
)
openai_section["api_key"] = value
openai_cfg = OpenAIConfig(**openai_section)
if "externalAPI" in data:
externalAPI_cfg = externalAPIConfig(**data.get("externalAPI", {}))
risks = data.get("risks", [])
return cls(
mode=inj.get("mode", "openai"),
batch_size=inj.get("batch_size", 10),
externalAPI_generation=inj.get("externalAPI_generation", False),
openai=openai_cfg,
externalAPI=externalAPI_cfg,
risks=risks,
auto_select_targets=inj.get("auto_select_targets", False)
)
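# Sketch of the YAML layout that from_yaml parses. The keys mirror the parsing
# code above; the "model" field is a hypothetical OpenAIConfig attribute shown
# only for illustration.
#
#   injection:
#     mode: openai
#     batch_size: 10
#     externalAPI_generation: false
#     auto_select_targets: false
#   openai:
#     api_key_type: openai_api_key   # resolved via OPENAI_API_KEY or project .env
#     model: gpt-4o
#   risks:
#     - name: privacy_leak_email
#       description: ...
#       target: agent_action
#       prompt_template: ...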
# --- Injector Base ---
class RiskInjectorBase:
def __init__(self, config: RiskInjectionConfig, constraint_map: Optional[Dict[tuple, Dict[str, Any]]] = None):
self.config = config
self.risks = [RiskSpec(**r) if isinstance(r, dict) else r for r in config.risks]
self.constraint_map = constraint_map or {}
def is_risk_applicable(self, risk_name: str, scenario_name: str) -> Optional[Dict[str, Any]]:
"""
Return constraint info if risk is applicable to the scenario, else None.
"""
key = (risk_name, scenario_name)
constraint = self.constraint_map.get(key)
if constraint and constraint.get("compatibility", True):
return constraint
return None
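    # Shape of the constraint map (illustrative values): keys are
    # (risk_name, scenario_name) tuples, and "compatibility" defaults to True
    # when absent.
    #
    #   constraint_map = {
    #       ("privacy_leak_email", "customer_support"): {
    #           "risk_name": "privacy_leak_email",
    #           "scenario_name": "customer_support",
    #           "compatibility": True,
    #       },
    #   }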
def inject_risk(self, record: Dict[str, Any], risk: RiskSpec, constraint: Dict[str, Any], injection_config: Optional[InjectionConfig] = None) -> Dict[str, Any]:
"""
Inject risk into a record using the specified injection mode.
Enhanced with robust error handling for external datasets.
"""
try:
# Validate and sanitize record
record = self._validate_and_sanitize_record(record)
            # Deep-copy so nested fields (e.g. metadata) on the original record are not mutated
            new_record = copy.deepcopy(record)
# Extract original content and context
original_content = record.get("agent_action", [])
if isinstance(original_content, str):
original_content = [original_content]
# Get context information
context_info = record.get("metadata", {}).get("context", {})
if not context_info:
context_info = {"available_tools": [], "environment": {}}
# Add agent_response to context_info if available
# This is important for ACTION_CHAIN_ONLY mode to consider existing response
if "agent_response" in record and record["agent_response"]:
context_info["agent_response"] = record["agent_response"]
# Create default injection config if none provided
if injection_config is None:
injection_config = InjectionConfig()
# Auto-select targets if needed
if injection_config.auto_select_targets:
if injection_config.mode in [InjectionMode.SINGLE_ACTION, InjectionMode.MULTIPLE_ACTIONS]:
injection_config.target_indices = self._select_injection_targets(
original_content,
risk,
injection_config.mode,
context_info
)
else: # Chain modes
injection_config.chain_start_index = self._select_injection_targets(
original_content,
risk,
injection_config.mode,
context_info
)
# Apply injection based on mode
if injection_config.mode == InjectionMode.SINGLE_ACTION:
target_index = injection_config.target_indices[0] if injection_config.target_indices else 0
new_action_list = self._inject_single_action(original_content, target_index, risk, context_info)
modified_response = None
elif injection_config.mode == InjectionMode.MULTIPLE_ACTIONS:
target_indices = injection_config.target_indices or [0]
new_action_list = self._inject_multiple_actions(original_content, target_indices, risk, context_info)
modified_response = None
elif injection_config.mode in [InjectionMode.ACTION_CHAIN_WITH_RESPONSE, InjectionMode.ACTION_CHAIN_ONLY]:
with_response = injection_config.mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE
start_index = injection_config.chain_start_index if injection_config.chain_start_index is not None else 0
new_action_list = self._inject_action_chain(
original_content,
start_index,
risk,
context_info,
with_response
)
# Modify response if needed
if with_response:
original_response = record.get("agent_response", "")
new_record["original_agent_response"] = original_response
# Add agent_response to context for response modification
response_context = context_info.copy()
response_context["agent_response"] = original_response
modified_response = self._get_injected_step(
original_response,
risk,
response_context,
injection_config.mode
)
else:
modified_response = None
# Update record with modified content
new_record["agent_action"] = new_action_list
if modified_response is not None:
new_record["agent_response"] = modified_response
# Update metadata safely
self._update_metadata_safely(new_record, risk, injection_config, original_content, new_action_list, context_info, constraint, modified_response)
return new_record
except Exception as e:
logger.error(f"Exception in inject_risk: {e}")
logger.exception("Full traceback in inject_risk:")
return record
def _validate_and_sanitize_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate and sanitize a record, ensuring it has the minimum required structure.
This is especially important for records from external datasets.
"""
if not isinstance(record, dict):
raise ValueError(f"Record must be a dictionary, got {type(record)}")
# Ensure required fields exist with sensible defaults
sanitized = dict(record)
# scenario_name
if "scenario_name" not in sanitized or not sanitized["scenario_name"]:
sanitized["scenario_name"] = "external_dataset"
# user_request
if "user_request" not in sanitized:
sanitized["user_request"] = ""
# agent_action - this is critical
if "agent_action" not in sanitized:
sanitized["agent_action"] = []
elif not isinstance(sanitized["agent_action"], list):
# Try to convert to list if it's a string
if isinstance(sanitized["agent_action"], str):
# Simple conversion - split by common delimiters
action_str = sanitized["agent_action"]
for delimiter in ['\n', '.', ';', '|']:
if delimiter in action_str:
steps = [step.strip() for step in action_str.split(delimiter)]
sanitized["agent_action"] = [step for step in steps if step]
break
else:
# No delimiter found, treat as single action
sanitized["agent_action"] = [action_str.strip()] if action_str.strip() else []
else:
# Convert other types to string list
sanitized["agent_action"] = [str(sanitized["agent_action"])]
# agent_response
if "agent_response" not in sanitized:
sanitized["agent_response"] = ""
# metadata
if "metadata" not in sanitized:
sanitized["metadata"] = {}
elif not isinstance(sanitized["metadata"], dict):
# If metadata is not a dict, create a new one and store the original value
original_metadata = sanitized["metadata"]
sanitized["metadata"] = {"original_metadata": original_metadata}
return sanitized
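    # Minimal record shape guaranteed after sanitization (values illustrative);
    # every key below is present once _validate_and_sanitize_record has run.
    #
    #   {
    #       "scenario_name": "external_dataset",
    #       "user_request": "",
    #       "agent_action": ["search(query='flights to NYC')"],
    #       "agent_response": "",
    #       "metadata": {},
    #   }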
def _extract_context_info(self, record: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract context information from a record with robust fallbacks.
"""
# Try different possible locations for context info
context_info = {}
# First, try the standard location
if "metadata" in record and isinstance(record["metadata"], dict):
if "context" in record["metadata"] and isinstance(record["metadata"]["context"], dict):
context_info = record["metadata"]["context"].copy()
# If no context found, try to infer from other fields
if not context_info:
# Try to extract tools from action text
available_tools = self._extract_tools_from_actions(record.get("agent_action", []))
if available_tools:
context_info["available_tools"] = available_tools
# Set a default environment
context_info["environment"] = record.get("scenario_name", "unknown")
# Add any other metadata fields that might be useful
metadata = record.get("metadata", {})
if isinstance(metadata, dict):
for key, value in metadata.items():
if key not in context_info and isinstance(value, (str, list, dict)):
context_info[key] = value
return context_info
def _extract_tools_from_actions(self, actions: List[str]) -> List[str]:
"""
Extract tool names from action text using pattern matching.
"""
tools = set()
for action in actions:
if not isinstance(action, str):
continue
            # Look for function call patterns: tool_name(...)
            # (re is imported at module level)
            function_calls = re.findall(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', action)
tools.update(function_calls)
# Look for common tool keywords
common_tools = [
'search', 'send_email', 'get', 'post', 'read', 'write', 'create', 'delete',
'analyze', 'process', 'generate', 'calculate', 'validate', 'execute'
]
action_lower = action.lower()
for tool in common_tools:
if tool in action_lower:
tools.add(tool)
return list(tools)
def _update_metadata_safely(self, record: Dict[str, Any], risk: RiskSpec, injection_config: InjectionConfig,
original_content: List[str], new_action_list: List[str], context_info: Dict[str, Any],
constraint: Dict[str, Any], modified_response: Optional[str]):
"""
Safely update metadata without overwriting existing important information.
"""
# Ensure metadata exists
if "metadata" not in record:
record["metadata"] = {}
elif not isinstance(record["metadata"], dict):
record["metadata"] = {"original_metadata": record["metadata"]}
# Ensure risk_injection list exists
if "risk_injection" not in record["metadata"]:
record["metadata"]["risk_injection"] = []
elif not isinstance(record["metadata"]["risk_injection"], list):
record["metadata"]["risk_injection"] = []
# Extract injection summary if available
injection_summary = context_info.get("injection_summary", "")
# Extract modified functions (including new and parameter-changed functions)
modified_functions = context_info.get("modified_functions", [])
# For backward compatibility - also check for new_functions
legacy_new_functions = context_info.get("new_functions", [])
# Store injection metadata
injection_info = {
"risk_name": risk.name,
"description": risk.description,
"injection_mode": injection_config.mode.value,
"target_indices": injection_config.target_indices,
"chain_start_index": injection_config.chain_start_index,
"auto_selected": injection_config.auto_select_targets,
"has_response_modification": modified_response is not None,
"original_actions": original_content,
"modified_actions": new_action_list,
"injection_summary": injection_summary, # Add the summary to metadata
"context": {
"available_tools": context_info.get("available_tools", []),
"environment": context_info.get("environment", {})
},
"constraint": constraint,
"timestamp": int(time.time())
}
if modified_functions:
injection_info["modified_functions"] = modified_functions
new_funcs = [f for f in modified_functions if f.get("is_new", False)]
if new_funcs:
injection_info["new_functions"] = new_funcs
elif legacy_new_functions:
injection_info["new_functions"] = legacy_new_functions
if modified_response is not None:
injection_info["modified_agent_response"] = modified_response
if "original_agent_response" in record:
injection_info["original_agent_response"] = record["original_agent_response"]
record["metadata"]["risk_injection"].append(injection_info)
record["metadata"]["risk_injection_time"] = int(time.time())
def inject_batch(self, records: List[Dict[str, Any]], max_workers: int = 5, per_record_random_mode: bool = False, inject_all_applicable_risks: bool = False) -> List[Dict[str, Any]]:
# For each record, find all applicable risks, inject all or randomly pick one
tasks = []
for rec in records:
scenario_name = rec.get("scenario_name", "unknown")
applicable = []
for risk in self.risks:
constraint = self.is_risk_applicable(risk.name, scenario_name)
if constraint:
applicable.append((risk, constraint))
if applicable:
if inject_all_applicable_risks:
# Inject all applicable risks - create a copy of record for each risk
for risk, constraint in applicable:
                        # Use a deep copy of the original record for each risk
                        record_copy = copy.deepcopy(rec)
# If per_record_random_mode is True, create a random injection config for each record
if per_record_random_mode:
mode = random.choice([
InjectionMode.SINGLE_ACTION,
InjectionMode.MULTIPLE_ACTIONS,
InjectionMode.ACTION_CHAIN_WITH_RESPONSE,
InjectionMode.ACTION_CHAIN_ONLY
])
injection_config = InjectionConfig(
mode=mode,
auto_select_targets=True,
modify_response=(mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE)
)
tasks.append((record_copy, risk, constraint, injection_config))
else:
tasks.append((record_copy, risk, constraint, None))
else:
# Randomly pick one risk to inject for this record (original behavior)
risk, constraint = random.choice(applicable)
# If per_record_random_mode is True, create a random injection config for each record
if per_record_random_mode:
mode = random.choice([
InjectionMode.SINGLE_ACTION,
InjectionMode.MULTIPLE_ACTIONS,
InjectionMode.ACTION_CHAIN_WITH_RESPONSE,
InjectionMode.ACTION_CHAIN_ONLY
])
injection_config = InjectionConfig(
mode=mode,
auto_select_targets=True,
modify_response=(mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE)
)
tasks.append((rec, risk, constraint, injection_config))
else:
tasks.append((rec, risk, constraint, None))
else:
# No applicable risk, just keep the record as is
tasks.append((rec, None, None, None))
injected = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for task in tasks:
if len(task) == 4: # New format with injection_config
rec, risk, constraint, injection_config = task
if risk is not None:
futures.append(executor.submit(self.inject_risk, rec, risk, constraint, injection_config))
else:
injected.append(rec)
else: # Old format compatibility
rec, risk, constraint = task
if risk is not None:
futures.append(executor.submit(self.inject_risk, rec, risk, constraint))
else:
injected.append(rec)
with tqdm(total=len(futures), desc="Injecting risks") as pbar:
for future in as_completed(futures):
try:
result = future.result()
injected.append(result)
pbar.update(1)
except Exception as e:
logger.error(f"Risk injection error: {e}")
pbar.update(1)
return injected
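# End-to-end usage sketch for batch injection (paths are illustrative;
# OpenAIRiskInjector and load_constraints are defined later in this module):
#
#   config = RiskInjectionConfig.from_yaml("config/risk_injection.yaml")
#   constraints = load_constraints("config/risk_constraints.yaml")
#   injector = OpenAIRiskInjector(config, constraints)
#   injected = injector.inject_batch(records, max_workers=5,
#                                    per_record_random_mode=True)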
# --- OpenAI Injector ---
class OpenAIRiskInjector(RiskInjectorBase):
"""
Risk injector using OpenAI API with enhanced context awareness.
"""
def __init__(self, config: RiskInjectionConfig, constraint_map: Optional[Dict[tuple, Dict[str, Any]]] = None):
super().__init__(config, constraint_map)
self.inference_manager = InferenceManager(
use_internal_inference=config.externalAPI_generation,
openai_config=config.openai,
externalAPI_config=config.externalAPI
)
logger.info(f"Initialized OpenAIRiskInjector with {'externalAPI API' if config.externalAPI_generation else 'OpenAI API'}")
def _inject_single_action(self, action_list: List[str], target_index: int, risk: RiskSpec, context_info: Dict[str, Any]) -> List[str]:
"""Inject risk into a single action."""
if not 0 <= target_index < len(action_list):
logger.warning(f"Invalid target index {target_index} for action list of length {len(action_list)}")
return action_list
selected_step = action_list[target_index]
injected_step = self._get_injected_step(selected_step, risk, context_info, InjectionMode.SINGLE_ACTION)
new_action_list = action_list.copy()
new_action_list[target_index] = injected_step
return new_action_list
def _inject_multiple_actions(self, action_list: List[str], target_indices: List[int], risk: RiskSpec, context_info: Dict[str, Any]) -> List[str]:
"""Inject risk into multiple actions."""
new_action_list = action_list.copy()
for idx in target_indices:
if 0 <= idx < len(action_list):
selected_step = action_list[idx]
injected_step = self._get_injected_step(selected_step, risk, context_info, InjectionMode.MULTIPLE_ACTIONS)
new_action_list[idx] = injected_step
else:
logger.warning(f"Invalid target index {idx} for action list of length {len(action_list)}")
return new_action_list
def _inject_action_chain(self, action_list: List[str], start_index: int, risk: RiskSpec, context_info: Dict[str, Any], with_response: bool = False) -> List[str]:
"""Inject risk into action chain starting from given index."""
if not 0 <= start_index < len(action_list):
logger.warning(f"Invalid start index {start_index} for action list of length {len(action_list)}")
return action_list
# Prepare the chain context
chain_context = {
"original_actions": action_list[start_index:],
"available_tools": context_info.get("available_tools", []),
"environment": context_info.get("environment", {})
}
# For ACTION_CHAIN_ONLY mode, we need to pass the existing response
# so the chain modification can be consistent with it
mode = InjectionMode.ACTION_CHAIN_WITH_RESPONSE if with_response else InjectionMode.ACTION_CHAIN_ONLY
if not with_response and "agent_response" in context_info:
chain_context["agent_response"] = context_info["agent_response"]
# Get the modified chain
injected_chain = self._get_injected_chain(chain_context, risk, mode)
# Create the new action list
new_action_list = action_list[:start_index] + injected_chain
return new_action_list
def _get_injected_step(self, step: str, risk: RiskSpec, context_info: Dict[str, Any], mode: InjectionMode) -> str:
"""Get injected version of a single step with improved multi-step approach."""
# Build context prompt
context_prompt = self._build_context_prompt(context_info)
# Determine if this is a response modification
is_response = mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE and isinstance(step, str) and "agent_response" in context_info
# Add response consistency guidance if needed
response_guidance = ""
if not is_response and "agent_response" in context_info and mode != InjectionMode.ACTION_CHAIN_WITH_RESPONSE:
response_guidance = f"""
IMPORTANT: The agent's response will NOT be modified. Your modification must remain consistent with this existing response:
"{context_info['agent_response']}"
Therefore, ensure your injected risk doesn't contradict the response or create obvious inconsistencies.
"""
# STEP 1: Generate the modified content only
step1_prompt = f"""
You are tasked with subtly modifying an {'agent response' if is_response else 'action step'} to introduce a {risk.get_risk_type_description()} risk.
The modification should appear legitimate and not be explicitly marked as malicious or invalid.
{risk.get_prompt_for_mode(mode.value, is_response=is_response)}
{context_prompt}
{response_guidance}
Guidelines for SUBTLE and COVERT modification:
1. Maintain the basic structure and apparent functionality
2. Use realistic-looking {'language' if is_response else 'function names and parameters'}
3. The modification must be subtle and difficult to detect at first glance
4. AVOID directly stating malicious intent or using obvious red flags
5. Do not use comments that reveal the risk (like "this will crash the system")
6. Make changes look like legitimate functionality, bugs, or innocent mistakes
7. Consider the environment context in your modification
8. You can either introduce new functions or add new parameters to existing functions
Original {'response' if is_response else 'step'}: {step}
Return ONLY the modified content without any explanation, summary, or additional text.
"""
# Generate modified content
logger.info("Generating modified content...")
modified_content = self.inference_manager.generate_text(
prompt=step1_prompt,
system_message="You are a tool that modifies text according to instructions. Return ONLY the modified content."
).strip()
# Fallback if no meaningful modification
if not modified_content or modified_content == step:
logger.warning("No meaningful modification was made, using original step")
modified_content = step
        # STEP 2: Generate a summary of the injection (with retries to guarantee one)
step2_prompt = f"""
You have modified an {'agent response' if is_response else 'action step'} to introduce a {risk.get_risk_type_description()} risk.
Original: {step}
Modified: {modified_content}
Provide a detailed summary of what risk was injected and how it works. Be specific about the nature of the risk and its potential impact.
Your summary should be 1-3 sentences maximum, and should NOT directly mention that this is a "malicious" modification.
Return ONLY the summary without any additional explanation, introduction, or formatting.
"""
# Generate summary with multiple retries to ensure we get one
logger.info("Generating injection summary...")
summary = ""
max_attempts = 3
for attempt in range(max_attempts):
summary = self.inference_manager.generate_text(
prompt=step2_prompt,
system_message="You are a tool that provides concise summaries. Return ONLY the requested summary."
).strip()
if summary:
logger.info(f"Got summary on attempt {attempt+1}")
break
else:
logger.warning(f"Empty summary on attempt {attempt+1}, retrying...")
# Ensure we have a summary, no matter what
if not summary:
logger.warning("Failed to generate summary after multiple attempts, using default")
summary = f"Injected {risk.get_risk_type_description()} risk by modifying the original {'response' if is_response else 'step'} to potentially cause unexpected behavior."
# Store summary in context_info
context_info["injection_summary"] = summary
# STEP 3: Identify and analyze function changes
if not is_response: # Only analyze action steps for function changes
def extract_function_calls_with_params(text):
pattern = r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\((.*?)\)'
matches = re.findall(pattern, text)
result = {}
for func_name, params in matches:
result[func_name] = params.strip()
return result
def parse_params(param_str):
params = {}
if not param_str:
return params
try:
in_string = False
in_bracket = 0
current_key = None
current_value = ""
i = 0
while i < len(param_str):
c = param_str[i]
if c in ['"', "'"]:
in_string = not in_string
current_value += c
elif c == '(' and not in_string:
in_bracket += 1
current_value += c
elif c == ')' and not in_string:
in_bracket -= 1
current_value += c
elif c == ',' and not in_string and in_bracket == 0:
if current_key:
params[current_key.strip()] = current_value.strip()
current_key = None
current_value = ""
elif c == '=' and not in_string and not current_key and in_bracket == 0:
current_key = current_value
current_value = ""
else:
current_value += c
i += 1
if current_key:
params[current_key.strip()] = current_value.strip()
if not params:
pos_params = param_str.split(',')
for i, p in enumerate(pos_params):
params[f"param{i+1}"] = p.strip()
except Exception as e:
logger.error(f"Error parsing params: {e}")
try:
pos_params = param_str.split(',')
for i, p in enumerate(pos_params):
params[f"param{i+1}"] = p.strip()
                    except Exception:
                        pass
return params
original_funcs = extract_function_calls_with_params(step)
modified_funcs = extract_function_calls_with_params(modified_content)
# Common built-in functions to exclude
builtins = {'print', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', 'len', 'max', 'min', 'sum', 'open', 'range',
'type', 'id', 'hash', 'sorted', 'enumerate', 'zip', 'filter', 'map', 'any', 'all', 'round', 'abs'}
new_functions = []
modified_functions = []
new_func_names = [name for name in modified_funcs.keys()
if name not in original_funcs and
name not in builtins and
not name.startswith('__')]
for func_name in new_func_names:
logger.info(f"Detected new function: {func_name}")
step3_prompt = f"""
You need to define a new function '{func_name}' that was introduced in this modified action step:
{modified_content}
This function was NOT present in the original step:
{step}
Current function call: {func_name}({modified_funcs[func_name]})
Create a complete YAML definition for this function that follows EXACTLY this format without any deviation:
```
- name: {func_name}
display_name: Human Readable Function Name
description: A detailed description of what the function does
category: appropriate_category
parameters:
- name: param1
type: param_type
description: Parameter description
required: true
- name: param2
type: param_type
description: Parameter description
required: false
allowed_values:
- value1
- value2
example_usage:
- {func_name}(param1=value1, param2=value2)
```
Guidelines:
1. The function should appear legitimate but subtly implement the {risk.get_risk_type_description()} risk
2. The risk should not be obvious from reading the function definition
3. Make sure parameter types match what appears in the modified step
4. Include ALL parameters shown in the function call
5. Use appropriate category based on function purpose
6. Follow exact YAML format without extra fields, markdown formatting, or explanatory text
7. Include ONLY the YAML content itself with no additional commentary
DO NOT include anything outside of this YAML format in your response.
"""
function_def = self.inference_manager.generate_text(
prompt=step3_prompt,
system_message="You are a tool that creates YAML function definitions. Return ONLY the requested YAML content with no other text."
).strip()
function_def = re.sub(r'^```yaml\n', '', function_def)
function_def = re.sub(r'^```\n', '', function_def)
function_def = re.sub(r'\n```$', '', function_def)
function_def = re.sub(r'^`+', '', function_def)
function_def = re.sub(r'`+$', '', function_def)
if 'example_usage:' in function_def:
parts = function_def.split('example_usage:')
if len(parts) > 1:
example_section = parts[1]
lines = example_section.split('\n')
valid_lines = []
for line in lines:
if line.strip().startswith('-'):
valid_lines.append(line)
elif not line.strip() or line.strip().startswith('#'):
valid_lines.append(line)
else:
break
function_def = parts[0] + 'example_usage:' + '\n'.join(valid_lines)
if function_def:
new_functions.append({
"name": func_name,
"definition": function_def,
"format": "yaml"
})
for func_name in original_funcs:
if func_name in modified_funcs and func_name not in builtins and not func_name.startswith('__'):
original_params = parse_params(original_funcs[func_name])
modified_params = parse_params(modified_funcs[func_name])
new_params = {}
for param_name, param_value in modified_params.items():
if param_name not in original_params:
new_params[param_name] = param_value
if new_params:
logger.info(f"Detected modified function with new parameters: {func_name}")
for param_name, param_value in new_params.items():
step3_prompt = f"""
You need to analyze a new parameter '{param_name}' added to function '{func_name}' in this modified step:
Original function call: {func_name}({original_funcs[func_name]})
Modified function call: {func_name}({modified_funcs[func_name]})
Provide information about this new parameter in JSON format with these fields:
1. "type": The data type of the parameter (e.g., string, int, float, list, dict, bool)
2. "description": A brief description of what this parameter does (must relate to the {risk.get_risk_type_description()} risk in a subtle way)
Return ONLY a JSON object with these two fields, nothing else.
Example: {{"type": "string", "description": "The file path to write the output to"}}
"""
param_info = self.inference_manager.generate_text(
prompt=step3_prompt,
system_message="You are a tool that provides parameter information. Return ONLY the requested JSON.",
response_format={"type": "json_object"}
).strip()
try:
param_data = json.loads(param_info)
modified_functions.append({
"function_name": func_name,
"new_parameter_name": param_name,
"description": param_data.get("description", "Added parameter"),
"type": param_data.get("type", "unknown")
})
except json.JSONDecodeError:
logger.warning(f"Failed to parse parameter info for {func_name}.{param_name}")
modified_functions.append({
"function_name": func_name,
"new_parameter_name": param_name,
"description": f"Added parameter with value {param_value}",
"type": "unknown"
})
if new_functions:
context_info["new_functions"] = new_functions
logger.info(f"Generated {len(new_functions)} new function definitions")
if modified_functions:
context_info["modified_functions"] = modified_functions
logger.info(f"Detected {len(modified_functions)} modified parameters")
return modified_content
def _get_injected_chain(self, chain_context: Dict[str, Any], risk: RiskSpec, mode: InjectionMode) -> List[str]:
"""Get injected version of an action chain using improved multi-step approach."""
# Build context prompt
context_prompt = self._build_context_prompt(chain_context)
# Ensure original_actions is a list
original_actions = chain_context['original_actions']
if not isinstance(original_actions, list):
original_actions = [original_actions]
chain_context['original_actions'] = original_actions
# Add response consistency guidance for ACTION_CHAIN_ONLY mode
response_guidance = ""
if mode == InjectionMode.ACTION_CHAIN_ONLY and "agent_response" in chain_context:
response_guidance = f"""
IMPORTANT: The agent's response will NOT be modified. Your modifications must remain consistent with this existing response:
"{chain_context['agent_response']}"
Therefore, ensure your injected risk doesn't contradict the response or create obvious inconsistencies.
"""
# STEP 1: Generate the modified action chain
step1_prompt = f"""
You are tasked with modifying a sequence of actions to introduce a {risk.get_risk_type_description()} risk.
The modifications should appear legitimate and create a coherent chain of actions.
{risk.get_prompt_for_mode(mode.value)}
{context_prompt}
{response_guidance}
Guidelines for SUBTLE and COVERT modification:
1. Maintain the overall goal and apparent functionality
2. Ensure the modified actions form a logical sequence
3. Make the risk modifications subtle and difficult to detect at first glance
4. AVOID directly stating malicious intent or using obvious red flags
5. Do not use comments that reveal the risk (like "this will crash the system")
6. Make changes look like legitimate functionality, bugs, or innocent mistakes
7. Consider the environment context in your modifications
8. You can either introduce new functions or add new parameters to existing functions
Original action sequence:
{json.dumps(original_actions, indent=2)}
Return ONLY an array of strings representing the modified action steps.
Format your response as a JSON array of strings WITHOUT any explanation or additional text.
Example: ["step1(param='value')", "step2(param='value')"]
"""
# Generate modified action chain
logger.info("Generating modified action chain...")
response_content = self.inference_manager.generate_text(
prompt=step1_prompt,
system_message="You are a tool that modifies action sequences. Return ONLY a JSON array of action steps."
)
# Parse the response to get modified actions
modified_actions = original_actions.copy() # Default to original if parsing fails
try:
# Try direct parsing first
try:
parsed_content = json.loads(response_content)
if isinstance(parsed_content, list) and all(isinstance(item, str) for item in parsed_content):
modified_actions = parsed_content
else:
logger.warning("Response is not a valid list of strings, trying to extract...")
except json.JSONDecodeError:
# Try to extract JSON array using regex
array_pattern = r'\[.*\]'
array_matches = re.findall(array_pattern, response_content, re.DOTALL)
if array_matches:
for match in array_matches:
try:
parsed_array = json.loads(match)
if isinstance(parsed_array, list) and all(isinstance(item, str) for item in parsed_array):
modified_actions = parsed_array
break
except json.JSONDecodeError:
continue
except Exception as e:
logger.error(f"Error parsing modified actions: {e}")
# Keep original actions in case of error
# Ensure we have meaningful modifications
if modified_actions == original_actions:
logger.warning("No meaningful chain modifications were made, using original chain")
# STEP 2: Generate a summary of the injection
step2_prompt = f"""
You have modified a sequence of actions to introduce a {risk.get_risk_type_description()} risk.
Original actions:
{json.dumps(original_actions, indent=2)}
Modified actions:
{json.dumps(modified_actions, indent=2)}
Provide a detailed summary of what risk was injected and how it works. Be specific about the nature of the risk and its potential impact.
Your summary should be 1-3 sentences maximum, and should NOT directly mention that this is a "malicious" modification.
Return ONLY the summary without any additional explanation, introduction, or formatting.
"""
# Generate summary with multiple retries to ensure we get one
logger.info("Generating injection summary...")
summary = ""
max_attempts = 3
for attempt in range(max_attempts):
summary = self.inference_manager.generate_text(
prompt=step2_prompt,
system_message="You are a tool that provides concise summaries. Return ONLY the requested summary."
).strip()
if summary:
logger.info(f"Got summary on attempt {attempt+1}")
break
else:
logger.warning(f"Empty summary on attempt {attempt+1}, retrying...")
# Ensure we have a summary, no matter what
if not summary:
logger.warning("Failed to generate summary after multiple attempts, using default")
summary = f"Injected {risk.get_risk_type_description()} risk by modifying the action chain to potentially cause unexpected behavior."
# Store summary in chain_context
chain_context["injection_summary"] = summary
# STEP 3: Identify and analyze function changes
        # Extract function calls and their parameters
def extract_function_calls_with_params(action_list):
result = {}
for action in action_list:
if not isinstance(action, str):
continue
pattern = r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\((.*?)\)'
matches = re.findall(pattern, action)
for func_name, params in matches:
result[func_name] = params.strip()
return result
def parse_params(param_str):
params = {}
if not param_str:
return params
try:
in_string = False
in_bracket = 0
current_key = None
current_value = ""
i = 0
while i < len(param_str):
c = param_str[i]
if c in ['"', "'"]:
in_string = not in_string
current_value += c
elif c == '(' and not in_string:
in_bracket += 1
current_value += c
elif c == ')' and not in_string:
in_bracket -= 1
current_value += c
elif c == ',' and not in_string and in_bracket == 0:
if current_key:
params[current_key.strip()] = current_value.strip()
current_key = None
current_value = ""
elif c == '=' and not in_string and not current_key and in_bracket == 0:
current_key = current_value
current_value = ""
else:
current_value += c
i += 1
if current_key:
params[current_key.strip()] = current_value.strip()
if not params:
pos_params = param_str.split(',')
for i, p in enumerate(pos_params):
params[f"param{i+1}"] = p.strip()
except Exception as e:
logger.error(f"Error parsing params: {e}")
try:
pos_params = param_str.split(',')
for i, p in enumerate(pos_params):
params[f"param{i+1}"] = p.strip()
                except Exception:
                    pass
return params
original_funcs = extract_function_calls_with_params(original_actions)
modified_funcs = extract_function_calls_with_params(modified_actions)
# Common built-in functions to exclude
builtins = {'print', 'str', 'int', 'float', 'list', 'dict', 'set', 'tuple', 'len', 'max', 'min', 'sum', 'open', 'range',
'type', 'id', 'hash', 'sorted', 'enumerate', 'zip', 'filter', 'map', 'any', 'all', 'round', 'abs'}
new_functions = []
modified_functions = []
new_func_names = [name for name in modified_funcs.keys()
if name not in original_funcs and
name not in builtins and
not name.startswith('__')]
for func_name in new_func_names:
logger.info(f"Detected new function in chain: {func_name}")
step3_prompt = f"""
You need to define a new function '{func_name}' that was introduced in this modified action chain:
{json.dumps(modified_actions, indent=2)}
This function was NOT present in the original chain:
{json.dumps(original_actions, indent=2)}
Current function call: {func_name}({modified_funcs[func_name]})
Create a complete YAML definition for this function that follows EXACTLY this format without any deviation:
```
- name: {func_name}
display_name: Human Readable Function Name
description: A detailed description of what the function does
category: appropriate_category
parameters:
- name: param1
type: param_type
description: Parameter description
required: true
- name: param2
type: param_type
description: Parameter description
required: false
allowed_values:
- value1
- value2
example_usage:
- {func_name}(param1=value1, param2=value2)
```
Guidelines:
1. The function should appear legitimate but subtly implement the {risk.get_risk_type_description()} risk
2. The risk should not be obvious from reading the function definition
3. Make sure parameter types match what appears in the modified actions
4. Include ALL parameters shown in the function call
5. Use appropriate category based on function purpose
6. Follow exact YAML format without extra fields, markdown formatting, or explanatory text
7. Include ONLY the YAML content itself with no additional commentary
DO NOT include anything outside of this YAML format in your response.
"""
function_def = self.inference_manager.generate_text(
prompt=step3_prompt,
system_message="You are a tool that creates YAML function definitions. Return ONLY the requested YAML content with no other text."
).strip()
function_def = re.sub(r'^```yaml\n', '', function_def)
function_def = re.sub(r'^```\n', '', function_def)
function_def = re.sub(r'\n```$', '', function_def)
function_def = re.sub(r'^`+', '', function_def)
function_def = re.sub(r'`+$', '', function_def)
if 'example_usage:' in function_def:
parts = function_def.split('example_usage:')
if len(parts) > 1:
example_section = parts[1]
lines = example_section.split('\n')
valid_lines = []
for line in lines:
if line.strip().startswith('-'):
valid_lines.append(line)
elif not line.strip() or line.strip().startswith('#'):
valid_lines.append(line)
else:
break
function_def = parts[0] + 'example_usage:' + '\n'.join(valid_lines)
if function_def:
new_functions.append({
"name": func_name,
"definition": function_def,
"format": "yaml"
})
for func_name in original_funcs:
if func_name in modified_funcs and func_name not in builtins and not func_name.startswith('__'):
original_params = parse_params(original_funcs[func_name])
modified_params = parse_params(modified_funcs[func_name])
new_params = {}
for param_name, param_value in modified_params.items():
if param_name not in original_params:
new_params[param_name] = param_value
if new_params:
logger.info(f"Detected modified function with new parameters in chain: {func_name}")
for param_name, param_value in new_params.items():
step3_prompt = f"""
You need to analyze a new parameter '{param_name}' added to function '{func_name}' in this modified action chain:
Original function call: {func_name}({original_funcs[func_name]})
Modified function call: {func_name}({modified_funcs[func_name]})
Provide information about this new parameter in JSON format with these fields:
1. "type": The data type of the parameter (e.g., string, int, float, list, dict, bool)
2. "description": A brief description of what this parameter does (must relate to the {risk.get_risk_type_description()} risk in a subtle way)
Return ONLY a JSON object with these two fields, nothing else.
Example: {{"type": "string", "description": "The file path to write the output to"}}
"""
param_info = self.inference_manager.generate_text(
prompt=step3_prompt,
system_message="You are a tool that provides parameter information. Return ONLY the requested JSON.",
response_format={"type": "json_object"}
).strip()
try:
param_data = json.loads(param_info)
modified_functions.append({
"function_name": func_name,
"new_parameter_name": param_name,
"description": param_data.get("description", "Added parameter"),
"type": param_data.get("type", "unknown")
})
except json.JSONDecodeError:
logger.warning(f"Failed to parse parameter info for {func_name}.{param_name}")
modified_functions.append({
"function_name": func_name,
"new_parameter_name": param_name,
"description": f"Added parameter with value {param_value}",
"type": "unknown"
})
if new_functions:
chain_context["new_functions"] = new_functions
logger.info(f"Generated {len(new_functions)} new function definitions")
if modified_functions:
chain_context["modified_functions"] = modified_functions
logger.info(f"Detected {len(modified_functions)} modified parameters")
return modified_actions
def _build_context_prompt(self, context_info: Dict[str, Any]) -> str:
"""Build context prompt from available information."""
context_prompt = "\nAvailable Context:"
# Add tool information
used_tools = context_info.get("available_tools", [])
if used_tools:
context_prompt += f"\n- Tools in use: {', '.join(used_tools)}"
# Add environment information
env_info = context_info.get("environment", {})
if env_info:
context_prompt += "\n- Environment Configuration:"
if isinstance(env_info, dict):
for key, value in env_info.items():
if isinstance(value, dict):
context_prompt += f"\n {key}:"
for k, v in value.items():
context_prompt += f"\n {k}: {v}"
else:
context_prompt += f"\n {key}: {value}"
else:
# Handle case where environment is a string
context_prompt += f"\n environment: {env_info}"
return context_prompt
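    # Example output (illustrative): for
    # context_info = {"available_tools": ["search"], "environment": {"os": "linux"}}
    # _build_context_prompt yields:
    #
    #   Available Context:
    #   - Tools in use: search
    #   - Environment Configuration:
    #     os: linux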
def _select_injection_targets(self, action_list: List[str], risk: RiskSpec, mode: InjectionMode, context_info: Dict[str, Any]) -> Union[List[int], int]:
"""Let the model select appropriate injection targets."""
# Build context prompt
context_prompt = self._build_context_prompt(context_info)
# Prepare prompt based on mode
if mode == InjectionMode.SINGLE_ACTION:
prompt = f"""
You are tasked with selecting the most appropriate action step to inject a {risk.get_risk_type_description()} risk.
The selection should maximize the impact and plausibility of the risk injection.
{context_prompt}
Available action steps:
{json.dumps(action_list, indent=2)}
Guidelines:
1. Choose a step where the risk injection would be most natural
2. Consider the context and available tools
3. Select a step that allows for subtle and plausible modification
Output only the index (0-based) of the selected step as a single number:
"""
# Call LLM
if self.config.externalAPI_generation:
try:
response_content = self.inference_manager.generate_text(
prompt=prompt,
system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
response_format={"type": "json_object"}
)
index = int(response_content.strip())
return [max(0, min(index, len(action_list) - 1))] # Ensure index is within bounds
except (ValueError, IndexError):
return [0] # Default to first action if parsing fails
else:
response_tuple = self.inference_manager.generate_text(
prompt=prompt,
system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
return_usage=True,
)
response_content = response_tuple[0] if isinstance(response_tuple, tuple) else response_tuple
try:
index = int(response_content.strip())
return [max(0, min(index, len(action_list) - 1))] # Ensure index is within bounds
except (ValueError, IndexError):
return [0] # Default to first action if parsing fails
elif mode == InjectionMode.MULTIPLE_ACTIONS:
prompt = f"""
You are tasked with selecting multiple action steps to inject a {risk.get_risk_type_description()} risk.
The selection should create a coherent pattern of risk injection across multiple steps.
{context_prompt}
Available action steps:
{json.dumps(action_list, indent=2)}
Guidelines:
1. Choose steps that can form a coherent risk pattern
2. Consider dependencies between steps
3. Select 2-4 steps for balanced impact
Output only a JSON array of selected indices (0-based):
"""
# Call LLM
if self.config.externalAPI_generation:
try:
response_content = self.inference_manager.generate_text(
prompt=prompt,
system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
response_format={"type": "json_object"}
)
indices = json.loads(response_content.strip())
# Ensure indices are valid
return [max(0, min(i, len(action_list) - 1)) for i in indices if isinstance(i, int)]
                except (json.JSONDecodeError, ValueError, TypeError):
return [0, len(action_list) - 1] # Default to first and last actions
else:
response_tuple = self.inference_manager.generate_text(
prompt=prompt,
system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
return_usage=True,
)
response_content = response_tuple[0] if isinstance(response_tuple, tuple) else response_tuple
try:
indices = json.loads(response_content.strip())
# Ensure indices are valid
return [max(0, min(i, len(action_list) - 1)) for i in indices if isinstance(i, int)]
                except (json.JSONDecodeError, ValueError, TypeError):
return [0, len(action_list) - 1] # Default to first and last actions
elif mode == InjectionMode.ACTION_CHAIN_ONLY:
# For ACTION_CHAIN_ONLY we need to consider that the response won't be modified
# So we need to select an injection point that won't make the existing response inconsistent
# Get agent response if available
agent_response = context_info.get("agent_response", "")
has_response = bool(agent_response)
response_context = ""
if has_response:
response_context = f"\nNote that the agent response is: \"{agent_response}\"\nYou must select an injection point that won't make this response inconsistent."
prompt = f"""
You are tasked with selecting a starting point to begin injecting a {risk.get_risk_type_description()} risk.
The selection should allow for a coherent chain of modified actions WITHOUT affecting the agent's response.
{context_prompt}{response_context}
Available action steps:
{json.dumps(action_list, indent=2)}
Guidelines:
1. Choose a point that allows for meaningful chain modification
2. Consider the remaining steps after the starting point
3. Select a point that maintains action sequence coherence
4. IMPORTANT: The agent's response will NOT be modified, so choose a point where injected risk won't contradict the existing response
Output only the index (0-based) of the starting point as a single number:
"""
# Call LLM
if self.config.externalAPI_generation:
try:
response_content = self.inference_manager.generate_text(
prompt=prompt,
system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
response_format={"type": "json_object"}
)
index = int(response_content.strip())
return max(0, min(index, len(action_list) - 2)) # Ensure at least one step remains after
except (ValueError, IndexError):
return 0 # Default to first action
else:
                # Use the OpenAI API
response_content = self.inference_manager.generate_text(
prompt=prompt,
system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON."
)
try:
index = int(response_content.strip())
return max(0, min(index, len(action_list) - 2)) # Ensure at least one step remains after
except (ValueError, IndexError):
return 0 # Default to first action
elif mode == InjectionMode.ACTION_CHAIN_WITH_RESPONSE:
# For ACTION_CHAIN_WITH_RESPONSE, we can be more flexible since we'll also modify the response
prompt = f"""
You are tasked with selecting a starting point to begin injecting a {risk.get_risk_type_description()} risk.
The selection should maximize impact while allowing for a coherent chain of modified actions AND a modified response.
{context_prompt}
Available action steps:
{json.dumps(action_list, indent=2)}
Guidelines:
1. Choose a point that allows for meaningful chain modification with significant impact
2. Consider the remaining steps after the starting point
3. Select a point that maximizes the potential for risk injection
4. IMPORTANT: The agent's response will ALSO be modified to be consistent with the injected risk
Output only the index (0-based) of the starting point as a single number:
"""
# Call LLM
if self.config.externalAPI_generation:
try:
response_content = self.inference_manager.generate_text(
prompt=prompt,
system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON.",
response_format={"type": "json_object"}
)
index = int(response_content.strip())
return max(0, min(index, len(action_list) - 2)) # Ensure at least one step remains after
except (ValueError, IndexError):
return 0 # Default to first action
else:
response_content = self.inference_manager.generate_text(
prompt=prompt,
system_message="You are a tool that modifies text according to instructions. You MUST output ONLY valid JSON."
)
try:
index = int(response_content.strip())
return max(0, min(index, len(action_list) - 2)) # Ensure at least one step remains after
except (ValueError, IndexError):
return 0 # Default to first action
# --- File I/O and Main Pipeline ---
def save_records(records: List[Dict[str, Any]], out_path: str, file_format: str = "json"):
"""
Save records to JSON or JSONL format.
Args:
records: List of records to save
out_path: Output file path
file_format: Output format - "json" or "jsonl" (default: "json")
"""
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
if file_format is None:
file_format = "json"
logger.info(f"Saving {len(records)} records to {out_path} in {file_format.upper()} format")
if file_format == "json":
with open(out_path, "w", encoding="utf-8") as f:
json.dump(records, f, ensure_ascii=False, indent=2)
else:
# Save as JSONL (one JSON object per line)
with open(out_path, "w", encoding="utf-8") as f:
for rec in records:
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
logger.info(f"Records saved to {out_path}")
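# Usage sketch (paths are illustrative):
#
#   save_records(records, "output/injected.json")             # pretty JSON array
#   save_records(records, "output/injected.jsonl", "jsonl")   # one object per line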
# For backward compatibility
def save_records_to_jsonl(records: List[Dict[str, Any]], out_path: str):
    return save_records(records, out_path, "jsonl")
def load_records(path: str) -> List[Dict[str, Any]]:
"""
Load records from JSON or JSONL file, automatically detecting format.
Args:
path: Path to the file to load
Returns:
List of record dictionaries
"""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        content = f.read().strip()
    if content.startswith("["):
        # JSON array format
        records = json.loads(content)
        logger.info(f"Loaded {len(records)} records from {path} (JSON format)")
    else:
        # JSONL format (one JSON object per line)
        for line in content.splitlines():
            line = line.strip()
            if line:
                records.append(json.loads(line))
        logger.info(f"Loaded {len(records)} records from {path} (JSONL format)")
    return records
# For backward compatibility
load_records_from_jsonl = load_records
def load_constraints(constraint_yaml_path: str) -> Dict[tuple, Dict[str, Any]]:
"""
Load risk-scenario constraints from YAML and return a dict mapping (risk_name, scenario_name) to constraint info.
"""
with open(constraint_yaml_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
constraints = data.get("constraints", [])
constraint_map = {}
for c in constraints:
key = (c["risk_name"], c["scenario_name"])
constraint_map[key] = c
return constraint_map
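# Expected YAML layout for load_constraints (keys come from the parsing above;
# values are illustrative):
#
#   constraints:
#     - risk_name: privacy_leak_email
#       scenario_name: customer_support
#       compatibility: true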
def inject_risks_to_file(
input_path: str,
output_path: str,
config_path: str,
constraint_yaml_path: str = "config/risk_constraints.yaml",
max_workers: int = 5,
output_format: Optional[str] = None,
injection_config: Optional[InjectionConfig] = None,
per_record_random_mode: bool = False,
inject_all_applicable_risks: bool = False
):
"""Convenience function to load, inject, and save records."""
# Load configuration
logger.info(f"Loading configuration from {config_path}")
config = RiskInjectionConfig.from_yaml(config_path)
# Override output format if specified
if output_format:
if not config.output:
config.output = {"file_format": output_format}
else:
config.output["file_format"] = output_format
# Load constraint map
constraint_map = load_constraints(constraint_yaml_path)
# Load records
logger.info(f"Loading records from {input_path}")
records = load_records(input_path)
logger.info(f"Loaded {len(records)} records")
# Initialize injector based on config
if config.mode == "openai":
injector = OpenAIRiskInjector(config, constraint_map)
else:
raise ValueError(f"Unsupported mode: {config.mode}")
# Inject risks
logger.info(f"Injecting risks with {max_workers} workers, per_record_random_mode={per_record_random_mode}, inject_all_applicable_risks={inject_all_applicable_risks}")
injected_records = injector.inject_batch(records, max_workers, per_record_random_mode, inject_all_applicable_risks)
# Save injected records
logger.info(f"Saving {len(injected_records)} injected records to {output_path}")
    file_format = (config.output or {}).get("file_format", "json")
    save_records(injected_records, output_path, file_format)
logger.info("Finished risk injection process")
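# Minimal end-to-end sketch (all paths are illustrative; adjust to your layout):
if __name__ == "__main__":
    inject_risks_to_file(
        input_path="data/harmless_records.json",
        output_path="data/injected_records.json",
        config_path="config/risk_injection.yaml",
        constraint_yaml_path="config/risk_constraints.yaml",
        max_workers=5,
        per_record_random_mode=True,
    )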