
Threshold Tuning for DynamoGuard Policies

DynamoGuard now supports setting classification thresholds to adjust policy performance. This tutorial demonstrates how to optimize those thresholds for your policies using your own dataset, covering threshold analysis, visualization, and threshold updates via the API.

Prerequisites

Before starting, ensure you have:

  • DynamoAI Access Token
  • DynamoAI API URL
  • Policy IDs to tune
  • Benchmark Dataset in required format (see Dataset Format section)
  • Desired False Positive Rate (FPR) target

Environment Setup

# Required packages
pip install numpy pandas matplotlib scikit-learn datasets aiohttp requests python-dotenv tqdm gradio scipy

Step 0: Benchmarking Dataset Format

Your dataset should be a CSV file with two columns:

  • prompt: the text to classify
  • label: the ground-truth label ('safe' or 'unsafe')

The column names must match the prompt_column and label_column arguments passed to load_data below; the benchmarking function in this tutorial uses the lowercase names prompt and label.

Example:

prompt,label
"Example prompt 1",safe
"Example prompt 2",unsafe

Threshold Analysis Methodology

Below we walk through the benchmarking script step by step, from data loading and benchmark configuration through threshold analysis, visualization, and an end-to-end example.

Step 1: Data Collection and Preprocessing

The first step involves setting up data loading and preprocessing functionality:

The code below provides:

  • Support for multiple data formats (CSV, TXT, JSON)
  • HuggingFace dataset integration
  • Column validation
  • Error handling
  • Token-based authentication for private datasets

import os
import json
import pandas as pd
from typing import Tuple, List, Optional
from datasets import load_dataset


def load_data(
    path: str,
    prompt_column: str = "prompt",
    label_column: str = "label",
    split: str = "train",
    use_auth_token: Optional[str] = None
) -> Tuple[List[str], List[str]]:
    """
    Load data from various sources and return prompts and labels.

    Args:
        path (str): Path to the data file or HuggingFace dataset name
        prompt_column (str): Name of the column containing prompts
        label_column (str): Name of the column containing labels
        split (str): Dataset split to use (for HuggingFace datasets)
        use_auth_token (str, optional): HuggingFace token for private datasets

    Returns:
        Tuple[List[str], List[str]]: Lists of prompts and corresponding labels
    """
    # Check if path exists (for local files)
    if os.path.exists(path):
        file_extension = os.path.splitext(path)[1].lower()

        # Handle CSV files
        if file_extension == '.csv':
            try:
                df = pd.read_csv(path)
                _validate_columns(df, prompt_column, label_column)
                return df[prompt_column].tolist(), df[label_column].tolist()
            except Exception as e:
                raise ValueError(f"Error reading CSV file: {str(e)}")

        # Handle TXT files
        elif file_extension == '.txt':
            try:
                prompts = []
                labels = []
                with open(path, 'r', encoding='utf-8') as f:
                    for line in f:
                        # Try to parse as JSON first
                        try:
                            data = json.loads(line.strip())
                            if isinstance(data, dict):
                                if prompt_column in data and label_column in data:
                                    prompts.append(str(data[prompt_column]))
                                    labels.append(str(data[label_column]))
                                    continue
                        except json.JSONDecodeError:
                            pass

                        # If not JSON, try tab-separated format
                        parts = line.strip().split('\t')
                        if len(parts) >= 2:
                            prompts.append(parts[0])
                            labels.append(parts[1])
                        else:
                            # Skip lines that don't match expected format
                            continue

                if not prompts or not labels:
                    raise ValueError("No valid data found in TXT file")
                return prompts, labels
            except Exception as e:
                raise ValueError(f"Error reading TXT file: {str(e)}")

        # Handle JSON files
        elif file_extension == '.json':
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # Handle different JSON formats
                if isinstance(data, list):
                    # List of dictionaries format
                    prompts = [str(item.get(prompt_column, "")) for item in data]
                    labels = [str(item.get(label_column, "")) for item in data]
                elif isinstance(data, dict):
                    # Dictionary format with lists
                    prompts = [str(p) for p in data.get(prompt_column, [])]
                    labels = [str(l) for l in data.get(label_column, [])]
                else:
                    raise ValueError("Unsupported JSON format")

                if not prompts or not labels:
                    raise ValueError("Required columns not found in JSON")
                return prompts, labels
            except Exception as e:
                raise ValueError(f"Error reading JSON file: {str(e)}")

        else:
            raise ValueError(f"Unsupported file extension: {file_extension}")

    # Handle HuggingFace datasets
    else:
        try:
            # Prepare dataset loading arguments
            dataset_kwargs = {
                'split': split,
            }

            # Only add the token if it's provided and not None
            if use_auth_token:
                dataset_kwargs['token'] = use_auth_token

            try:
                # First try loading with token if provided
                dataset = load_dataset(path, **dataset_kwargs)
            except Exception as first_error:
                if use_auth_token:
                    # If that fails, try without token
                    print(f"Failed to load with token, trying without: {str(first_error)}")
                    dataset_kwargs.pop('token')
                    dataset = load_dataset(path, **dataset_kwargs)
                else:
                    raise first_error

            if prompt_column not in dataset.column_names or label_column not in dataset.column_names:
                raise ValueError(f"Required columns '{prompt_column}' and/or '{label_column}' not found in dataset")

            return dataset[prompt_column], dataset[label_column]

        except Exception as e:
            if use_auth_token is None:
                raise ValueError(f"Error loading dataset: {str(e)}. If this is a private dataset, please provide a Hugging Face token.")
            raise ValueError(f"Error loading dataset: {str(e)}")


def _validate_columns(df: pd.DataFrame, prompt_column: str, label_column: str) -> None:
    """
    Validate that required columns exist in the DataFrame.

    Args:
        df (pd.DataFrame): DataFrame to validate
        prompt_column (str): Name of prompt column
        label_column (str): Name of label column

    Raises:
        ValueError: If required columns are missing
    """
    missing_columns = []
    if prompt_column not in df.columns:
        missing_columns.append(prompt_column)
    if label_column not in df.columns:
        missing_columns.append(label_column)

    if missing_columns:
        raise ValueError(f"Missing required columns: {', '.join(missing_columns)}")

Step 2: Setting Up Environment and Configuration

Before we begin threshold tuning, we need to set up our environment and configuration. This includes loading environment variables and establishing our benchmark configuration class.

  1. First, create a .env file in your project root directory with your API tokens:

STAGING_API_TOKEN=your_staging_token_here
PROD_API_TOKEN=your_prod_token_here

  2. Copy and paste the following code to set up the configuration:

# Utility imports
import os
from pathlib import Path
from typing import Dict, List
from dotenv import load_dotenv
from dataclasses import dataclass

load_dotenv()

# Configuration
LOGGING_VERBOSE = True
RESULT_VERBOSE = True
CONFIG_VERBOSE = False

# Load environment variables from .env file
env_path = Path('./.env')
if env_path.exists():
    load_dotenv(env_path)
else:
    raise FileNotFoundError(
        "Please create an .env file with required API tokens. " +
        "You need to define the STAGING_API_TOKEN or PROD_API_TOKEN, or both.")


@dataclass
class BenchmarkConfig:
    """Configuration for the benchmark"""
    LOGGING_VERBOSE: bool = True
    RESULT_VERBOSE: bool = True
    CONFIG_VERBOSE: bool = True

    # API Configuration
    API_ENDPOINTS: Dict[str, str] = None

    def __post_init__(self):
        self.API_ENDPOINTS = {
            'prod_analyze': 'https://api.dynamo.ai/v1/moderation/analyze/',
            'prod_chat': 'https://api.dynamo.ai/v1/moderation/chat/session_id',
            'staging_analyze': 'https://api.staging.dynamo.ai/v1/moderation/analyze/',
            'staging_chat': 'https://api.staging.dynamo.ai/v1/moderation/model/66f6e19a97fd00bf8bfa055b/chat/session_id'
        }

    @property
    def API_ENDPOINT_OPTIONS(self) -> List[str]:
        return list(self.API_ENDPOINTS.values())

    def get_api_token(self, endpoint: str) -> str:
        """Get the appropriate API token based on the endpoint"""
        if 'staging' in endpoint:
            print("Staging token set")
            token = os.getenv("STAGING_API_TOKEN")
            if not token:
                raise ValueError("STAGING_API_TOKEN not found in environment variables")
            return token
        else:
            token = os.getenv("PROD_API_TOKEN")
            print("Prod token set")
            if not token:
                raise ValueError("PROD_API_TOKEN not found in environment variables")
            return token

    def validate_endpoint(self, endpoint: str) -> bool:
        """Validate if the endpoint is supported"""
        return endpoint in self.API_ENDPOINT_OPTIONS
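
A short sketch of how the configuration is typically used, assuming the .env file above is in place:

config = BenchmarkConfig()

# List the supported endpoints and pick one
print(config.API_ENDPOINT_OPTIONS)
staging_url = config.API_ENDPOINTS['staging_analyze']

# Validate the endpoint and resolve the matching token from the environment
assert config.validate_endpoint(staging_url)
token = config.get_api_token(staging_url)  # reads STAGING_API_TOKEN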

Step 3: Cache Management Setup

To optimize performance and avoid redundant API calls, we'll implement a caching system. This system will store and retrieve benchmark results based on policy IDs and dataset combinations. The CacheManager class provides:

  • Automatic cache directory creation
  • Unique cache file generation based on policy IDs and dataset
  • Methods to save and load benchmark results
  • Cache validation and timestamp tracking

Key features:

  1. Cache files are stored in ./benchmark_results/cache by default
  2. File names are generated using MD5 hashes of policy IDs and dataset names
  3. Cache entries include:
    • Policy IDs
    • Benchmark dataset name
    • Results and logits
    • Timestamp of cache creation
  4. Automatic validation of cache entries when loading

This caching system will significantly speed up subsequent runs of the benchmark with the same configuration.

Copy and paste the following code to set up the cache management system:

import os
import hashlib
import pickle
from datetime import datetime
from typing import List, Dict, Optional, Tuple

class CacheManager:
    """Handles all caching operations"""
    def __init__(self, cache_dir: str = './benchmark_results/cache'):
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def _get_cache_filename(self, policy_ids: List[str], benchmark_dataset: str) -> str:
        policy_str = '_'.join(sorted(policy_ids))
        cache_key = f"{policy_str}_{benchmark_dataset}"
        hash_key = hashlib.md5(cache_key.encode()).hexdigest()
        return os.path.join(self.cache_dir, f"logits_cache_{hash_key}.pkl")

    def save_to_cache(self, policy_ids: List[str], benchmark_dataset: str,
                      all_results: List, all_logits: Dict, labels: List, verbose: bool = True):
        cache_file = self._get_cache_filename(policy_ids, benchmark_dataset)
        cache_data = {
            'policy_ids': policy_ids,
            'benchmark_dataset': benchmark_dataset,
            'results': all_results,
            'logits': all_logits,
            'labels': labels,
            'timestamp': datetime.now().isoformat()
        }
        with open(cache_file, 'wb') as f:
            pickle.dump(cache_data, f)
        if verbose:
            print(f"Cached results saved to: {cache_file}")

    def load_from_cache(self, policy_ids: List[str], benchmark_dataset: str,
                        verbose: bool = True) -> Optional[Tuple[List, Dict, List]]:
        cache_file = self._get_cache_filename(policy_ids, benchmark_dataset)
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)
            # Validate that the cache entry matches the requested configuration
            if (sorted(cache_data['policy_ids']) == sorted(policy_ids) and
                    cache_data['benchmark_dataset'] == benchmark_dataset):
                if verbose:
                    print(f"Loading cached results from: {cache_file}")
                    print(f"Cache timestamp: {cache_data['timestamp']}")
                return cache_data['results'], cache_data['logits'], cache_data['labels']
        return None
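
For illustration, a cache round trip looks like this (the policy IDs, dataset name, and results below are placeholders):

cache = CacheManager()

policy_ids = ["policy_id_1", "policy_id_2"]          # placeholder IDs
dataset_name = "my-benchmark"                        # placeholder dataset name
results = [("Example prompt 1", "safe", "safe")]     # (prompt, label, prediction)
logits = {"policy_id_1": [[0.9, 0.1]], "policy_id_2": [[0.8, 0.2]]}
labels = ["safe"]

cache.save_to_cache(policy_ids, dataset_name, results, logits, labels)

cached = cache.load_from_cache(policy_ids, dataset_name)
if cached is not None:
    results, logits, labels = cached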

Step 4: API Client Implementation

The API Client handles all interactions with the moderation API, including request handling, batching, and retry logic. This component is crucial for reliable data collection and efficient API usage.

Copy and paste the following code to implement the API client:

import asyncio
import aiohttp
from typing import Dict, List, Optional

class APIClient:
    """Handles all API interactions"""
    def __init__(self, api_url: str, headers: Dict[str, str], policy_id: List[str],
                 model_id: Optional[str] = None):
        self.api_url = api_url
        self.headers = headers
        self.policy_id = policy_id
        # model_id is only needed for chat-style endpoints
        self.model_id = model_id

    async def send_request(self, prompt, label, policy_ids):
        """Send a single request with multiple policy IDs"""
        label = label.lower()

        async with aiohttp.ClientSession() as session:
            if 'chat/session_id' in self.api_url:
                json_data = {
                    'messages': [{'role': 'user', 'content': prompt}],
                    "modelId": self.model_id
                }
            else:
                json_data = {
                    'messages': [{'role': 'user', 'content': prompt}],
                    "textType": "MODEL_INPUT",
                    "policyIds": policy_ids,  # Send multiple policy IDs in one request
                    "return_logits": True
                }

            try:
                await asyncio.sleep(1)
                async with session.post(self.api_url, headers=self.headers, json=json_data, ssl=False) as response:
                    response_json = await response.json()
                    print(response_json)

                    # Process results for each policy
                    results = {}
                    for policy in response_json.get('appliedPolicies', []):
                        policy_id = policy['policy']['_id']
                        guard_score = policy['outputs'].get('guard_score', 0)
                        logits = [1 - guard_score, guard_score]
                        pred = 'unsafe' if logits[1] > logits[0] else 'safe'
                        results[policy_id] = (pred, logits)

                    return prompt, label, results

            except Exception as e:
                print(f"Request failed: {e}")
                return prompt, label, None

    async def send_batched_requests(self, prompts, labels, batch_size, verbose_logging=False):
        all_results = []
        all_logits = {policy_id: [] for policy_id in self.policy_id}
        final_labels = []  # Track labels for successful requests only
        i = 0

        async def send_request_with_retry(prompt, label, policy_ids, max_retries=3, base_delay=1):
            for attempt in range(max_retries):
                try:
                    await asyncio.sleep(base_delay * (attempt + 1))
                    result = await self.send_request(prompt, label, policy_ids)
                    return result
                except Exception as e:
                    if attempt == max_retries - 1:
                        print(f"Failed after {max_retries} attempts: {e}")
                        return prompt, label, None
                    print(f"Attempt {attempt + 1} failed, retrying...")

        # Limit concurrency to five in-flight requests at a time
        semaphore = asyncio.Semaphore(5)

        async def send_with_limit(prompt, label):
            async with semaphore:
                return await send_request_with_retry(prompt, label, self.policy_id)

        while i < len(prompts):
            batch_prompts = []
            batch_labels = []

            while i < len(prompts) and len(batch_prompts) < batch_size:
                if labels[i] is not None:
                    batch_prompts.append(prompts[i])
                    batch_labels.append(labels[i])
                i += 1

            # Dictionary to store results for current batch
            batch_results_by_prompt = {}
            failed_prompts = set()

            # Send one request per prompt with all policy IDs
            tasks = [
                send_with_limit(prompt, label)
                for prompt, label in zip(batch_prompts, batch_labels)
            ]
            batch_results = await asyncio.gather(*tasks)

            successful_requests = 0
            failed_requests = 0

            for result in batch_results:
                prompt, label, policy_results = result

                if policy_results is None:
                    failed_requests += 1
                    failed_prompts.add(prompt)
                    continue

                successful_requests += 1
                final_labels.append(label)  # Only add labels for successful requests

                # Process results for each policy
                if prompt not in batch_results_by_prompt:
                    batch_results_by_prompt[prompt] = {
                        'label': label,
                        'predictions': [],
                        'any_unsafe': False
                    }

                # Store results for each policy
                for policy_id, (pred, logits) in policy_results.items():
                    all_logits[policy_id].append(logits)
                    batch_results_by_prompt[prompt]['predictions'].append(pred)
                    if pred == 'unsafe':
                        batch_results_by_prompt[prompt]['any_unsafe'] = True

            if verbose_logging:
                print("Batch results:")
                print(f"Successful requests: {successful_requests}")
                print(f"Failed requests: {failed_requests}")

            # Convert combined results to final format
            for prompt, result in batch_results_by_prompt.items():
                if prompt not in failed_prompts and len(result['predictions']) == len(self.policy_id):
                    final_pred = 'unsafe' if result['any_unsafe'] else 'safe'
                    all_results.append((prompt, result['label'], final_pred))

            if i < len(prompts):
                await asyncio.sleep(0.5)

        # Keep only results for which every policy returned logits
        valid_indices = []
        for idx, (prompt, _, _) in enumerate(all_results):
            if all(len(all_logits[policy_id]) > idx for policy_id in self.policy_id):
                valid_indices.append(idx)

        all_results = [all_results[i] for i in valid_indices]
        final_labels = [final_labels[i] for i in valid_indices]  # Filter labels using the same indices
        for policy_id in all_logits:
            all_logits[policy_id] = [all_logits[policy_id][i] for i in valid_indices]

        if verbose_logging:
            print("\nFinal Results Summary:")
            print(f"Total prompts processed: {len(prompts)}")
            print(f"Total results collected: {len(all_results)}")
            print(f"Total valid labels: {len(final_labels)}")
            print(f"Prompts excluded due to failures: {len(prompts) - len(all_results)}")
            for policy_id in all_logits:
                print(f"Policy {policy_id} logits collected: {len(all_logits[policy_id])}")
        return all_results, all_logits, final_labels  # Return the filtered labels
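
A minimal sketch of driving the client directly, assuming a valid analyze endpoint URL, bearer token, and policy IDs (all placeholders here):

import asyncio

async def collect_logits():
    client = APIClient(
        api_url="https://api.dynamo.ai/v1/moderation/analyze/",  # or your staging URL
        headers={
            "Content-Type": "application/json",
            "Authorization": "Bearer your_api_token_here",  # placeholder token
        },
        policy_id=["policy_id_1", "policy_id_2"],  # placeholder policy IDs
    )
    prompts = ["Example prompt 1", "Example prompt 2"]
    labels = ["safe", "unsafe"]
    return await client.send_batched_requests(prompts, labels, batch_size=20)

# In a script; inside a notebook use: results, logits, labels = await collect_logits()
results, logits, labels = asyncio.run(collect_logits())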

Step 5: Threshold Analysis Implementation

The Threshold Analyzer is responsible for analyzing and optimizing classification thresholds for multiple policies. It supports both ROC curve-based and gradient-based optimization approaches. The ThresholdAnalyzer provides:

  1. Two optimization methods:
    • ROC curve-based analysis
    • Gradient-based optimization
  2. Threshold updating capabilities via API
  3. Optimal threshold selection based on desired false positive rate (FPR)
  4. Comprehensive metrics tracking including:
    • Accuracy
    • F1 Score
    • Confusion matrix components:
      • True Negatives (TN)
      • False Positives (FP)
      • False Negatives (FN)
      • True Positives (TP)
  5. Detailed logging of the evaluation process

Key features:

  • Supports multiple policy optimization
  • Handles API interactions for threshold updates
  • Provides detailed metrics for performance evaluation
  • Allows for different optimization strategies
  • Handles multiple policy evaluation
  • Provides detailed debugging information
  • Robust error handling with traceback
  • Clear metric reporting
  • Support for both dictionary and array-based thresholds

This component is the core of the threshold tuning workflow.

Copy and paste the following code to implement the threshold analyzer:

import numpy as np
from sklearn.metrics import roc_curve, confusion_matrix, f1_score, accuracy_score
from scipy.optimize import minimize
import aiohttp

class ThresholdAnalyzer:
    """Analyzes logits and selects classification thresholds for one or more policies"""

    def __init__(self, api_url=None, headers=None, policy_id=None, config=None):
        # api_url and headers are required by update_thresholds;
        # config enables verbose logging during gradient optimization
        self.api_url = api_url
        self.headers = headers
        self.policy_id = policy_id or []
        self.config = config

    async def update_thresholds(self, thresholds: dict):
        """Update thresholds for policies on the platform"""
        # Get base URL based on endpoint
        base_url = self.api_url.split('/analyze')[0] if '/analyze' in self.api_url else self.api_url.split('/chat')[0]

        async with aiohttp.ClientSession() as session:
            results = []
            for policy_id, threshold in thresholds.items():
                update_url = f"{base_url}/policy/{policy_id}/update-threshold"
                json_data = {
                    "classification_threshold": float(threshold)
                }

                try:
                    async with session.put(
                        update_url,
                        headers=self.headers,
                        json=json_data,
                        ssl=False
                    ) as response:
                        if response.status == 200:
                            print(f"Successfully updated threshold for policy {policy_id} to {threshold}")
                            results.append({
                                "policy_id": policy_id,
                                "status": "success",
                                "threshold": threshold
                            })
                        else:
                            error_text = await response.text()
                            print(f"Failed to update threshold for policy {policy_id}: {error_text}")
                            results.append({
                                "policy_id": policy_id,
                                "status": "failed",
                                "error": error_text
                            })
                except Exception as e:
                    print(f"Error updating threshold for policy {policy_id}: {str(e)}")
                    results.append({
                        "policy_id": policy_id,
                        "status": "error",
                        "error": str(e)
                    })

            return results

    def analyze_thresholds(self, all_logits, true_labels, desired_fpr, mode="uniform"):
        if mode == "gradient":
            return self.gradient_analyze_thresholds(all_logits, true_labels, desired_fpr)
        else:
            return self.roc_analyze_thresholds(all_logits, true_labels, desired_fpr)

    def find_optimal_thresholds(self, metrics, desired_fpr=None):
        """Find optimal threshold combination based on overall FPR"""
        if desired_fpr is not None:
            idx = np.argmin(np.abs(np.array(metrics['fpr']) - desired_fpr))
        else:
            idx = np.argmax(metrics['f1'])

        optimal_thresholds = metrics['threshold_combinations'][idx]

        return {
            'thresholds': {str(k): float(v) for k, v in optimal_thresholds.items()},
            'accuracy': float(metrics['accuracy'][idx]),
            'fpr': float(metrics['fpr'][idx]),
            'fnr': float(metrics['fnr'][idx]),
            'f1': float(metrics['f1'][idx])
        }

    def roc_analyze_thresholds(self, all_logits, true_labels, desired_fpr):
        y_true = np.array([1 if label == 'unsafe' else 0 for label in true_labels])
        logits_per_policy = {policy_id: np.array(logits)[:, 1] for policy_id, logits in all_logits.items()}

        print(f"\nAnalyzing thresholds for desired FPR: {desired_fpr}")
        print(f"Label distribution: {np.bincount(y_true)}")  # Show class distribution

        def select_threshold(proba, target, target_fpr=0.05):
            fpr, tpr, thresholds = roc_curve(target, proba, pos_label=1)

            # Debug prints
            print("\nROC curve details:")
            print(f"Target FPR: {target_fpr}")
            print(f"Available FPRs: {fpr}")
            print(f"Available thresholds: {thresholds}")

            # Find closest FPR to target
            idx = np.argmin(np.abs(fpr - target_fpr))
            print(f"Chosen index: {idx}")
            print(f"Resulting FPR: {fpr[idx]}")
            print(f"Chosen threshold: {thresholds[idx]}")

            # Interpolate between the two closest points when the target FPR
            # falls between grid points on the ROC curve
            if fpr[idx] < target_fpr and idx + 1 < len(fpr):
                # Find the two FPR points that bracket our target
                fpr1, fpr2 = fpr[idx], fpr[idx + 1]
                thresh1, thresh2 = thresholds[idx], thresholds[idx + 1]

                # Interpolate to get the exact threshold we want
                if fpr2 - fpr1 > 0:  # Avoid division by zero
                    alpha = (target_fpr - fpr1) / (fpr2 - fpr1)
                    interpolated_threshold = thresh1 + alpha * (thresh2 - thresh1)
                    print(f"Interpolated threshold: {interpolated_threshold}")
                    return interpolated_threshold

            return thresholds[idx]

        metrics = {
            'threshold_combinations': [],
            'accuracy': [],
            'fpr': [],
            'fnr': [],
            'f1': []
        }

        thresholds = {}
        for policy_id, logits in logits_per_policy.items():
            print(f"\nAnalyzing policy: {policy_id}")
            print(f"Logits range: [{np.min(logits):.6f}, {np.max(logits):.6f}]")
            print(f"Logits distribution: {np.percentile(logits, [0, 25, 50, 75, 100])}")

            threshold = select_threshold(
                logits,
                y_true,
                target_fpr=desired_fpr  # Use the desired FPR directly
            )
            thresholds[policy_id] = threshold

        # Verify predictions with the full threshold combination
        accuracy, fpr, fnr, f1 = self.evaluate_thresholds(thresholds, logits_per_policy, y_true)
        print(f"Verification - FPR: {fpr:.4f}, FNR: {fnr:.4f}")

        metrics['threshold_combinations'].append(thresholds)
        metrics['accuracy'].append(accuracy)
        metrics['fpr'].append(fpr)
        metrics['fnr'].append(fnr)
        metrics['f1'].append(f1)
        return metrics

    def gradient_analyze_thresholds(self, all_logits, true_labels, desired_fpr):
        """Analyze thresholds using gradient-based optimization"""
        y_true = np.array([1 if label == 'unsafe' else 0 for label in true_labels])
        logits_per_policy = {policy_id: np.array(logits) for policy_id, logits in all_logits.items()}

        metrics = {
            'threshold_combinations': [],
            'accuracy': [],
            'fpr': [],
            'fnr': [],
            'f1': []
        }

        def objective_function(thresholds):
            predictions_per_policy = []
            for idx, policy_id in enumerate(logits_per_policy.keys()):
                logits = logits_per_policy[policy_id]
                pred = (logits[:, 1] > thresholds[idx]).astype(int)
                predictions_per_policy.append(pred)

            y_pred = np.any(predictions_per_policy, axis=0).astype(int)
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            fpr = fp / (fp + tn) if (fp + tn) > 0 else 0

            # Objective: minimize distance to desired FPR while maximizing F1
            fpr_penalty = 100 * abs(fpr - desired_fpr)
            f1 = f1_score(y_true, y_pred, pos_label=1)

            if self.config and self.config.LOGGING_VERBOSE:
                print(f"Thresholds: {thresholds}, FPR: {fpr:.4f}, F1: {f1:.4f}")

            return fpr_penalty - f1  # Minimize this

        best_result = None
        best_score = float('inf')

        for _ in range(5):
            # Random initialization around the median unsafe score of each policy
            x0_perturbed = [
                np.median(logits_per_policy[policy_id][:, 1]) + np.random.normal(0, 0.1)
                for policy_id in logits_per_policy]

            # Optimization bounds
            bounds = [(logits[:, 1].min(), logits[:, 1].max())
                      for logits in logits_per_policy.values()]

            # Run optimization
            result = minimize(
                objective_function,
                x0_perturbed,
                method='L-BFGS-B',
                bounds=bounds,
                options={'maxiter': 200}
            )
            # Validate result
            score = objective_function(result.x)
            if score < best_score:
                best_score = score
                best_result = result

        # Store results as a {policy_id: threshold} dict so downstream code
        # (find_optimal_thresholds, plotting) can treat both modes uniformly
        optimal_thresholds = {
            policy_id: float(t)
            for policy_id, t in zip(logits_per_policy.keys(), best_result.x)
        }
        if all(t == 0.0 for t in optimal_thresholds.values()):
            print("Warning: All thresholds are 0.0, this might indicate an optimization problem")

        # Debug information
        print(f"Optimization result: {best_result.message}")
        print(f"Final objective value: {best_score}")

        # Verify predictions with these thresholds
        accuracy, fpr, fnr, f1 = self.evaluate_thresholds(optimal_thresholds, logits_per_policy, y_true)
        print(f"Verification - FPR: {fpr:.4f}, FNR: {fnr:.4f}")

        metrics['threshold_combinations'].append(optimal_thresholds)
        metrics['accuracy'].append(accuracy)
        metrics['fpr'].append(fpr)
        metrics['fnr'].append(fnr)
        metrics['f1'].append(f1)

        return metrics

    def evaluate_thresholds(self, thresholds, logits_per_policy, y_true):
        print("\nStarting threshold evaluation...")
        print(f"Thresholds: {thresholds}")
        print(f"Logits per policy keys: {list(logits_per_policy.keys())}")
        print(f"Y true shape: {y_true.shape}")

        predictions_per_policy = []

        try:
            for idx, policy_id in enumerate(logits_per_policy.keys()):
                print(f"\nProcessing policy {policy_id}")
                logits = np.asarray(logits_per_policy[policy_id])
                # Accept either 1-D unsafe scores or 2-D [safe, unsafe] logits
                scores = logits[:, 1] if logits.ndim == 2 else logits
                threshold = thresholds[policy_id] if isinstance(thresholds, dict) else thresholds[idx]

                print(f"Logits shape: {scores.shape}")
                print(f"Threshold: {threshold}")
                print(f"Logits range: [{scores.min():.4f}, {scores.max():.4f}]")

                pred = (scores > threshold).astype(int)
                print(f"Predictions made: {len(pred)}")
                predictions_per_policy.append(pred)

            print("\nCombining predictions...")
            y_pred = np.any(predictions_per_policy, axis=0).astype(int)
            print(f"Combined predictions shape: {y_pred.shape}")

            print("\nCalculating metrics...")
            accuracy = accuracy_score(y_true, y_pred)
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

            print("\nConfusion Matrix:")
            print(f"TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}")

            fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
            fnr = fn / (fn + tp) if (fn + tp) > 0 else 0
            f1 = f1_score(y_true, y_pred, pos_label=1)

            print("\nFinal metrics:")
            print(f"Accuracy: {accuracy:.4f}")
            print(f"FPR: {fpr:.4f}")
            print(f"FNR: {fnr:.4f}")
            print(f"F1: {f1:.4f}")

            return accuracy, fpr, fnr, f1

        except Exception as e:
            print(f"Error in evaluate_thresholds: {str(e)}")
            import traceback
            traceback.print_exc()
            return 0, 0, 0, 0
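
Once logits and labels have been collected, a typical flow through the analyzer looks like the sketch below. The constructor arguments and the threshold push are only needed if you want to write updated thresholds back to the platform; the endpoint URL and token here are placeholders.

analyzer = ThresholdAnalyzer(
    api_url="https://api.dynamo.ai/v1/moderation/analyze/",   # placeholder endpoint
    headers={"Authorization": "Bearer your_api_token_here"},  # placeholder token
)

# all_logits: {policy_id: [[safe, unsafe], ...]}, labels: ['safe' | 'unsafe', ...]
metrics = analyzer.analyze_thresholds(all_logits, labels, desired_fpr=0.05, mode="uniform")
best = analyzer.find_optimal_thresholds(metrics, desired_fpr=0.05)
print(best['thresholds'], best['fpr'], best['f1'])

# Optionally push the chosen thresholds back to the platform
# (await inside a notebook, or wrap in asyncio.run in a script)
# update_results = await analyzer.update_thresholds(best['thresholds'])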

Step 6: Visualization Implementation

The Visualizer class handles the creation and saving of plots to analyze threshold performance. It supports both single and multiple policy visualization with detailed metrics.

Copy and paste the following code to implement the visualizer:

import os
import matplotlib.pyplot as plt
import numpy as np

class Visualizer:
    """Handles all visualization tasks"""
    def __init__(self, save_dir: str = './benchmark_results', saving_title='results'):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        self.saving_title = saving_title

    def plot_threshold_analysis(self, policy_id, metrics, desired_fpr=None):
        """Plot threshold analysis results with multiple visualization options"""
        print(f"Number of policies: {len(policy_id)}")  # Debug print

        if len(policy_id) == 1:
            print("Using single policy plot")
            self._plot_single_policy(metrics, desired_fpr)
            return "single"
        else:
            print("Using multiple policies plot")
            self._plot_multiple_policies(metrics, desired_fpr)
            return "multiple"

    def _plot_single_policy(self, metrics, desired_fpr=None):
        """Plot for single policy case"""
        plt.figure(figsize=(12, 6))

        # Get the threshold values from threshold_combinations
        thresholds = [list(tc.values())[0] for tc in metrics['threshold_combinations']]
        plt.bar('Accuracy', metrics['accuracy'], label='Accuracy')
        plt.bar('FPR', metrics['fpr'], label='False Positive Rate')
        plt.bar('FNR', metrics['fnr'], label='False Negative Rate')
        plt.bar('F1', metrics['f1'], label='F1 Score')

        if desired_fpr is not None:
            plt.axhline(y=desired_fpr, color='r', linestyle='--',
                        label=f'Desired FPR: {desired_fpr}')

        plt.xlabel('Metrics')
        plt.ylabel('Score')
        plt.title('Metrics vs Threshold')
        plt.legend()
        plt.grid(True)

        self._save_plot("single")

    def _plot_multiple_policies(self, metrics, desired_fpr=None):
        """Plot for multiple policies using mean threshold and subplots"""
        # Create figure with two subplots side by side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

        # Sort by FPR for better visualization
        sort_idx = np.argsort(metrics['fpr'])
        fpr = np.array(metrics['fpr'])[sort_idx]
        accuracy = np.array(metrics['accuracy'])[sort_idx]
        fnr = np.array(metrics['fnr'])[sort_idx]
        f1 = np.array(metrics['f1'])[sort_idx]
        threshold_combinations = np.array(metrics['threshold_combinations'], dtype=object)[sort_idx]

        # Plot 1: Metrics vs Mean Threshold
        mean_thresholds = [np.mean(list(tc.values())) for tc in threshold_combinations]

        ax1.bar('Accuracy', accuracy, label='Accuracy')
        ax1.bar('FPR', fpr, label='False Positive Rate')
        ax1.bar('FNR', fnr, label='False Negative Rate')
        ax1.bar('F1', f1, label='F1 Score')

        if desired_fpr is not None:
            ax1.axhline(y=desired_fpr, color='r', linestyle='--',
                        label=f'Desired FPR: {desired_fpr}')

        # Plot horizontal lines for each threshold combination
        colors = plt.get_cmap('viridis')(np.linspace(0, 1, len(threshold_combinations)))
        for idx, (tc, color) in enumerate(zip(threshold_combinations, colors)):
            for policy_id, threshold in tc.items():
                ax1.axhline(y=threshold, color=color, linestyle=':',
                            alpha=0.3, label=f'Threshold {idx} - {policy_id}: {threshold:.4f}')

        ax1.set_xlabel('Mean Threshold')
        ax1.set_ylabel('Score')
        ax1.set_title('Metrics vs Mean Threshold')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax1.grid(True)

        # Plot 2: Individual Policy Thresholds at Optimal Point
        # Find optimal point (closest to desired FPR)
        if desired_fpr is not None:
            optimal_idx = np.argmin(np.abs(fpr - desired_fpr))
        else:
            optimal_idx = np.argmax(f1)

        optimal_thresholds = threshold_combinations[optimal_idx]  # This is a dictionary

        # Convert dictionary to lists for plotting
        policy_ids = list(optimal_thresholds.keys())
        threshold_values = list(optimal_thresholds.values())

        ax2.bar(policy_ids, threshold_values)
        ax2.set_xlabel('Policy ID')
        ax2.set_ylabel('Optimal Threshold')
        ax2.set_title(f'Optimal Thresholds per Policy\n(FPR: {fpr[optimal_idx]:.3f})')
        plt.xticks(rotation=45)
        ax2.grid(True)

        plt.tight_layout()
        self._save_plot("multiple")

    def _save_plot(self, plot_type):
        """Save the plot with appropriate naming"""
        # Ensure the output directory exists
        os.makedirs(self.save_dir, exist_ok=True)

        # Create the full path
        save_title = self.saving_title.replace("/", "_")
        plot_path = os.path.join(
            self.save_dir,
            f'threshold_analysis_{plot_type}_{save_title}.png'
        )

        # Save the plot
        print(f"Saving plot to: {plot_path}")
        plt.savefig(plot_path)
        plt.close()

        # Verify the save
        if os.path.exists(plot_path):
            print(f"Plot saved successfully to: {plot_path}")
        else:
            print(f"Failed to save plot to: {plot_path}")

Step 7: Main Benchmarking Class Implementation

The DGuardThresholdBenchmark class serves as the main orchestrator, integrating all components for threshold optimization and evaluation. Copy and paste the following code to implement the main benchmarking class:

import os
from typing import List, Dict, Tuple

class DGuardThresholdBenchmark:
    def __init__(
        self,
        api_url: str,
        policy_id: str | List[str],
        config: BenchmarkConfig,
        saving_title: str = "results",
        mode: str = "uniform"
    ):
        if not config.validate_endpoint(api_url):
            raise ValueError(f"Invalid API endpoint: {api_url}")

        self.api_url = api_url
        self.policy_id = [policy_id] if isinstance(policy_id, str) else policy_id
        self.model_id = policy_id
        self.config = config
        self.mode = mode

        # Initialize components
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {config.get_api_token(api_url)}',
        }
        self.cache_manager = CacheManager()
        self.api_client = APIClient(
            api_url=api_url,
            headers=headers,
            policy_id=self.policy_id
        )

        # Pass endpoint details through so the analyzer can push updated
        # thresholds back to the platform and log verbosely if configured
        self.analyzer = ThresholdAnalyzer(
            api_url=api_url,
            headers=headers,
            policy_id=self.policy_id,
            config=config
        )
        self.visualizer = Visualizer(saving_title=saving_title)

        # Create benchmark_results directory if it doesn't exist
        os.makedirs('./benchmark_results', exist_ok=True)
        # Add cache directory
        self.cache_dir = './benchmark_results/cache'
        os.makedirs(self.cache_dir, exist_ok=True)

    async def run_evaluation(self, benchmark_dataset: str, prompts: List[str],
                             labels: List[str], batch_size: int = 50) -> Tuple[List, Dict, List]:
        """Run evaluation with caching"""
        # Try to load from cache first
        cached_data = self.cache_manager.load_from_cache(self.policy_id, benchmark_dataset)
        if cached_data is not None:
            return cached_data

        # If not in cache, run evaluation
        results, logits, labels = await self.api_client.send_batched_requests(
            prompts, labels, batch_size, self.config.LOGGING_VERBOSE)

        # Save to cache
        self.cache_manager.save_to_cache(self.policy_id, benchmark_dataset, results, logits, labels)

        return results, logits, labels

The DGuardThresholdBenchmark class provides:

  1. Integration of all components:

    • Cache Manager
    • API Client
    • Threshold Analyzer
    • Metrics Evaluator
    • Visualizer
  2. Configuration management:

    • API endpoint validation
    • Policy ID handling
    • Directory setup
    • Results storage
  3. Key features:

    • Cached evaluation results
    • Batch processing support
    • Flexible policy ID handling (single or multiple)
    • Comprehensive results storage
    • Automatic directory management
  4. Main functionalities:

    • Component initialization
    • Cache directory setup
    • Results management
    • Evaluation execution with caching

This class serves as the central coordinator for the threshold tuning process, managing all components and their interactions.

Step 8: Benchmarking Function

The function below ties all of the above components together into a single entry point for threshold tuning.

async def run_benchmark_with_threshold(
    # Setup parameters
    policy_id_dict: dict,
    api_endpoint='staging_analyze',  # or 'prod_analyze'
    benchmark_dataset="dynamofl/benchmark-experian-security-legal-compliance-input",
    desired_fpr=0.05,  # Optional
    mode="uniform"
):
    policy_id = list(policy_id_dict.values())
    policy_names = list(policy_id_dict.keys())

    # Load the benchmark configuration
    config = BenchmarkConfig()
    if config.CONFIG_VERBOSE:
        print("Available API endpoints:", config.API_ENDPOINT_OPTIONS)

    # Initialize benchmark with config
    benchmark = DGuardThresholdBenchmark(
        api_url=config.API_ENDPOINTS[api_endpoint],
        policy_id=policy_id,
        config=config,
        saving_title=benchmark_dataset
    )

    # Load dataset
    prompts, labels = load_data(
        path=benchmark_dataset,
        prompt_column="prompt",
        label_column="label",
        split="train"
    )

    # Run evaluation with caching
    results, logits, labels = await benchmark.run_evaluation(
        benchmark_dataset=benchmark_dataset,
        prompts=prompts,
        labels=labels,
        batch_size=20
    )

    # Analyze thresholds
    metrics = benchmark.analyzer.analyze_thresholds(logits, labels, desired_fpr, mode)
    print("Optimal threshold results:", metrics['threshold_combinations'])

    # Plot results
    plot_type = benchmark.visualizer.plot_threshold_analysis(policy_id, metrics, desired_fpr)

    # Build the path that matches the Visualizer's naming scheme
    dataset_name = benchmark_dataset.replace("/", "_")
    plot_path = f'./benchmark_results/threshold_analysis_{plot_type}_{dataset_name}.png'
    return plot_path, str(metrics)

Step 9: Example Usage

Here's how to use the threshold tuning system with a complete example. This implementation shows how to run a benchmark with threshold optimization for multiple policies. Call the function above with your endpoint choice, policy dictionary, and benchmark path (either a Hugging Face dataset name or a local .csv, .txt, or .json file).

Key features:

  • Supports multiple policies
  • Configurable API endpoint
  • Custom dataset support
  • Desired FPR targeting
  • Automatic result saving
  • Visualization generation

Make sure to:

  1. Have valid policy IDs
  2. Configure the correct API endpoint
  3. Specify the desired false positive rate
  4. Have access to the benchmark dataset

The function returns:

  • Path to the generated visualization plot
  • Optimal threshold configuration as a string

# Define your policies (replace with your actual policy IDs)
policy_dict = {
    "Policy1": "policy_id_1",
    "Policy2": "policy_id_2"
}

# Run the benchmark
plot_path, optimal_results = await run_benchmark_with_threshold(
    policy_id_dict=policy_dict,
    api_endpoint='staging_analyze',
    benchmark_dataset="dynamofl/benchmark-experian-security-legal-compliance-input",
    desired_fpr=0.05
)

print(f"Results saved to: {plot_path}")
print(f"Optimal thresholds: {optimal_results}")

Metrics Explanation

  • FPR (False Positive Rate): Ratio of false positives to total negatives

    • Key metric for controlling false alarms
    • Target value typically set based on business requirements
    • Lower values mean stricter classification
  • FNR (False Negative Rate): Ratio of false negatives to total positives

    • Indicates missed unsafe content
    • Trade-off with FPR
    • Important for safety considerations
  • Accuracy: Overall classification accuracy

    • General measure of performance
    • May be misleading with imbalanced datasets
    • Consider alongside other metrics
  • F1 Score: Harmonic mean of precision and recall

    • Balanced measure of performance
    • Particularly useful for imbalanced datasets
    • Helps optimize threshold selection
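
For reference, these metrics map onto the confusion matrix exactly as computed in evaluate_thresholds above; a small self-contained sketch with toy labels:

import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score

# Toy ground truth and predictions (1 = unsafe, 0 = safe)
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 0, 1, 0, 1, 1, 0, 1])

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
fpr = fp / (fp + tn)   # false positives among all actual negatives
fnr = fn / (fn + tp)   # false negatives among all actual positives
print(f"Accuracy: {accuracy_score(y_true, y_pred):.2f}, "
      f"FPR: {fpr:.2f}, FNR: {fnr:.2f}, F1: {f1_score(y_true, y_pred):.2f}")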

Best Practices

  1. ROC Analysis

    • Start with ROC-based analysis for single policies
    • Use interpolation for precise threshold targeting
    • Validate results visually
  2. Multi-Policy Optimization

    • Use gradient-based optimization for multiple policies
    • Consider policy interactions
    • Balance individual and ensemble performance
  3. FPR Targeting

    • Set realistic FPR targets based on requirements
    • Consider business impact of false positives
    • Monitor FPR after deployment
  4. Validation

    • Use separate validation dataset
    • Monitor performance over time
    • Regularly re-evaluate thresholds
  5. Caching

    • Utilize caching for efficiency
    • Maintain cache hygiene
    • Update cache with new data

Troubleshooting

Common issues and solutions:

  1. API Connection

    • Check API tokens
    • Verify endpoint URLs
    • Handle rate limiting
  2. Cache Management

    • Clear stale cache
    • Verify cache integrity
    • Handle cache misses
  3. Threshold Optimization

    • Check convergence
    • Validate interpolation
    • Monitor optimization bounds
  4. Multi-Policy Issues

    • Check policy interactions
    • Validate ensemble behavior
    • Monitor individual policy performance

Copyright 2024 DynamoAI, Inc. All rights reserved.

This software is provided "as is", without warranty of any kind, express or implied. Use of this software is governed by the Terms of Use between you or your employer and DynamoAI.