Skip to main content

Walkthrough: Summarization - Privacy Aware Fine Tuning

In this tutorial, we show how to use DynamoFL's ML SDK to:

  • Train LLMs with differential privacy.
  • Evaluate trained LLMs with privacy attacks.

Prerequisites:

  • DynamoFL API token
  • The dynamofl-privacy SDK installed
  • The additional dependencies for this tutorial installed: pip install -r requirements.txt

Copyright 2023 DynamoFL, Inc. All rights reserved.
All use of the software is subject to the license terms and conditions in the Terms of Use between you or your employer and DynamoFL. Without limiting the terms and conditions of the Terms of Use, you may not:

  • Use the software in any manner that isn't directly tied to your use of the DynamoFL platform, including any use that attempts to remove or circumvent the dependency of the software on the DynamoFL platform;
  • Reverse engineer or otherwise attempt to discover underlying structure, ideas, or algorithms of the software;
  • Copy, distribute, publish, or otherwise make available the software.

You may make minor edits to the software solely as necessary to train your models on the DynamoFL platform.

Setup

At this stage, we assume you have already installed the dynamofl-privacy SDK and have an API token.

# Notebook shell commands (IPython "!" prefix) that prepare the Python environment.
# NOTE(review): `pip config unset` edits a pip configuration file, so the index-url reset
# persists beyond this notebook session; pip may also report an error if the key is not set.
!pip config unset global.index-url # reset the pip index so regular packages can be installed as normal
!pip install -r requirements.txt
'''
DynamoFL SDK imports
'''
from dynamofl.privacy import (
get_dp_seq2seq_trainer,
PrivacyArguments,
)
'''
DynamoFL tutorial helper functions
'''
from hf_utils import (
ModelArguments,
DPArguments,
setup_logging,
seq2seq_model_loading,
prepare_data_for_training,
Seq2SeqDataTrainingArguments,
compute_rouge_metrics,
EvaluateFirstStepCallback
)

import warnings


def warn(*args, **kwargs):
    """No-op stand-in for ``warnings.warn``; accepts and ignores any arguments.

    Kept so the module-level name ``warn`` remains available to existing code.
    """


# Silence warnings through the warnings filter machinery instead of monkey-patching
# ``warnings.warn``. The filter also covers modules that did ``from warnings import warn``
# before this cell ran, and it can be undone with ``warnings.resetwarnings()``.
warnings.filterwarnings("ignore")
import yaml
import os

Configs

In this section, we define the configuration for this run.

We prepared two configurations for this tutorial: one with DP and one without.
The batch sizes are chosen so that training fits on a single V100 GPU (16 GB), e.g. a p3.2xlarge instance.

# WANDB_PROJECT names the wandb project for this run; TOKENIZERS_PARALLELISM=false
# removes tokenizer warnings.
os.environ.update(
    WANDB_PROJECT="dynamofl-quickstart-3",
    TOKENIZERS_PARALLELISM="false",
)

# Path to the YAML file holding the training configuration. Swap which line is
# commented out to switch experiments:
#   configs/config_nodp_lora.yml  -> non-DP experiment (disable_dp = True)
#   configs/config_dp100_lora.yml -> DP experiment (epsilon = 100)
train_config = "configs/config_nodp_lora.yml"
# train_config = "configs/config_dp100_lora.yml"

def read_yaml_file(filename):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed document.

    Args:
        filename: Path to the YAML file.

    Returns:
        The parsed document (``yaml.safe_load`` yields ``None`` for an empty file).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid YAML; the original ``yaml.YAMLError``
            is chained as ``__cause__``.
    """
    with open(filename, "r", encoding="utf-8") as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as error:
            # Fail here, naming the file. Returning None would only surface later
            # as an opaque TypeError when the config is indexed.
            raise ValueError(f"Could not parse YAML file {filename!r}: {error}") from error
# Parse the selected training config; `use_lora` and the `lora_*` keys are read from it.
train_config_yaml = read_yaml_file(train_config)
# Toggle using low-rank adapters to speed up training and reduce memory consumption
USE_LORA = train_config_yaml['use_lora']

if USE_LORA:
    print("Using LORA")
else:
    print("Not using LORA")

Fine-tune an LLM

In this section, we will show how to use DynamoFL's ML SDK to fine-tune an LLM with differential privacy.
For benchmarking purposes, we also provide a disable_dp flag (see the config) that turns off differential privacy.

'''
Normal training routine imports
'''
import transformers
import os
import time
from transformers import (
set_seed,
Seq2SeqTrainer,
DataCollatorForSeq2Seq,
Seq2SeqTrainingArguments
)
from peft import LoraConfig, TaskType, get_peft_model
def train(args: DPArguments):
    """Fine-tune a seq2seq model, with or without differential privacy.

    Besides ``args``, this reads the module-level ``USE_LORA`` flag and, when LoRA is
    enabled, the ``lora_rank``, ``lora_alpha``, ``lora_dropout`` and
    ``lora_target_modules`` entries of the module-level ``train_config_yaml`` dict.

    Args:
        args: Bundle of the ``train``, ``privacy``, ``model`` and ``data`` argument
            groups, which are customized by the train_config YAML.

    Returns:
        The trainer after ``trainer.train()`` has completed. When DP is enabled
        (``privacy_args.disable_dp`` is false), the final PRV and RDP epsilon values
        have also been logged through ``trainer.log``.
    """
    train_args = args.train
    privacy_args = args.privacy
    model_args = args.model
    data_args = args.data

    setup_logging(train_args)
    # Set seed before initializing model.
    set_seed(train_args.seed)

    # Load in model and tokenizer.
    model, tokenizer = seq2seq_model_loading(model_args)
    # Only fall back to EOS for padding when the tokenizer defines no pad token.
    # Overwriting an existing pad token gives pad and EOS the same id, so any step
    # that masks pad ids (e.g. in labels) would also mask genuine EOS tokens.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = model.cuda()  # requires a CUDA device
    model.train()

    # Load in dataset and parse it for the correct format.
    train_dataset, test_dataset, _ = prepare_data_for_training(
        data_args, model_args, train_args, model, tokenizer
    )

    def metrics_fn(eval_pred):
        # Bind the tokenizer so the trainer can call this with the eval predictions alone.
        return compute_rouge_metrics(eval_pred, tokenizer=tokenizer)

    # Optionally wrap the model with LoRA adapters, configured from the training YAML.
    if USE_LORA:
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            inference_mode=False,
            r=train_config_yaml['lora_rank'],
            lora_alpha=train_config_yaml['lora_alpha'],
            lora_dropout=train_config_yaml['lora_dropout'],
            target_modules=train_config_yaml['lora_target_modules'],
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # [DynamoFL] Easily set your trainer for differentially private training with one line of code
    if privacy_args.disable_dp:
        # Normal training, no DP (for comparison).
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
        )
        trainer = Seq2SeqTrainer(
            model=model,
            tokenizer=tokenizer,
            args=train_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            data_collator=data_collator,
            compute_metrics=metrics_fn,
        )
    else:
        # Differentially Private trainer with dynamofl-privacy sdk.
        trainer = get_dp_seq2seq_trainer(
            model=model,
            tokenizer=tokenizer,
            train_args=train_args,
            privacy_args=privacy_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            compute_metrics=metrics_fn,
        )

    trainer.add_callback(EvaluateFirstStepCallback())

    trainer.train()

    if not privacy_args.disable_dp:
        # Log differential privacy metrics.
        eps_prv = trainer.get_prv_epsilon()
        eps_rdp = trainer.get_rdp_epsilon()
        trainer.log({
            "final_epsilon_prv": eps_prv,  # estimated upper bound of true epsilon
            "final_epsilon_rdp": eps_rdp,  # privacy budget (epsilon) expended
        })

    return trainer
def save_model(trainer, privacy_args):
    """Save the trained model to ``trainer.args.output_dir`` via ``save_pretrained``.

    With DP enabled (``privacy_args.disable_dp`` false), the trainer's model exposes
    the underlying model as ``_module``, and that inner module is what gets saved.
    """
    destination = trainer.args.output_dir
    print(f"Saving model to {destination}")

    model_to_save = trainer.model if privacy_args.disable_dp else trainer.model._module
    model_to_save.save_pretrained(destination)

If you already have a trained model and only want to run the privacy attacks, load the arguments (next cell) and skip to the next section.

# One parser fills all four argument dataclasses from the same YAML file; keys that
# match none of them are ignored because allow_extra_keys=True.
arg_parser = transformers.HfArgumentParser(
    (
        Seq2SeqTrainingArguments,
        PrivacyArguments,
        ModelArguments,
        Seq2SeqDataTrainingArguments,
    )
)
train_args, privacy_args, model_args, data_args = arg_parser.parse_yaml_file(
    yaml_file=os.path.abspath(train_config),
    allow_extra_keys=True,
)
args = DPArguments(
    train=train_args,
    privacy=privacy_args,
    model=model_args,
    data=data_args,
)

The next cell will start the training process. You will be prompted to enter a wandb key (recommended for keeping track of the train and test metrics).

If you prefer not to use wandb, change report_to: wandb to report_to: none in the config file (in the configs/ folder).

# Train (with or without DP, per the config) and save the result to the output directory.
trainer = train(args)
save_model(trainer=trainer, privacy_args=privacy_args)

Assess the LLM's vulnerability to privacy attacks

After training an LLM, we can assess its vulnerability to privacy attacks.
For this quickstart, we assess the model's vulnerability to Membership Inference Attacks (MIAs), a type of privacy attack that aims to determine whether a given data point was used to train the model.

def launch_attack(args: DPArguments):
    """Pen-test the fine-tuned model with a DynamoFL membership inference attack.

    Loads environment variables from a ``.env`` file (python-dotenv) and reads
    ``API_KEY``/``API_HOST``. It then creates a DynamoFL model from the weights in
    ``args.train.output_dir`` and a dataset from the Hugging Face dataset
    ``args.data.dataset_name``, and launches a membership inference test. Finally it
    polls the test's first attack until its status is ``"ERROR"`` or ``"COMPLETED"``.

    Args:
        args: Bundle of the ``train``, ``model`` and ``data`` argument groups
            (the ``privacy`` group is not used here).

    Returns:
        The final attack info returned by ``dfl.get_attack_info``.

    Raises:
        ValueError: if ``API_KEY`` is missing or empty.
    """
    from dynamofl import DynamoFL, GPUConfig
    from dotenv import load_dotenv

    load_dotenv()

    train_args = args.train
    model_args = args.model
    data_args = args.data

    api_key = os.getenv("API_KEY", "")
    api_host = os.getenv("API_HOST", "")
    if not api_key:
        # An empty key cannot authenticate; fail before any remote call is attempted.
        raise ValueError("API_KEY is not set; define it in your environment or .env file.")
    slug = time.time()  # timestamp used to make the dataset key unique per run
    dfl = DynamoFL(api_key, host=api_host)

    # NOTE(review): recent peft/transformers versions save *.safetensors by default;
    # confirm the pinned versions still write these .bin file names.
    model_file_name = "adapter_model.bin" if USE_LORA else "pytorch_model.bin"
    model_file_path = os.path.join(train_args.output_dir, model_file_name)

    model = dfl.create_model(
        name=train_args.output_dir,
        model_file_path=model_file_path,
        architecture=model_args.model_name_or_path,
        peft_config_path=os.path.join(train_args.output_dir, "adapter_config.json") if USE_LORA else None,
    )

    # NOTE(review): Hugging Face-style ModelArguments often define `use_auth_token`
    # as a bool; confirm it holds an actual token string before passing it as hf_token.
    dataset = dfl.create_hf_dataset(
        name="Train and test dataset",
        key=f"dataset_{slug}",
        hf_id=data_args.dataset_name,
        hf_token=model_args.use_auth_token,
    )

    test_info = dfl.create_membership_inference_test(
        name="Membership Inference Test",
        model_key=model.key,
        dataset_id=dataset._id,
        gpu=GPUConfig(gpu_type="v100", gpu_count=2),
        input_column=data_args.text_column,
        reference_column=data_args.summary_column,
        base_model=model_args.model_name_or_path,
        hf_token=model_args.use_auth_token,
    )

    terminal_statuses = {"ERROR", "COMPLETED"}

    def query_attack_status(attack_id):
        """Poll ``dfl.get_attack_info`` every 20 s until the attack reaches a terminal status."""
        attack_info = dfl.get_attack_info(attack_id)
        while attack_info["status"] not in terminal_statuses:
            print(f"Attack status: {attack_info['status']}. Retrying after 20 seconds...")
            time.sleep(20)
            attack_info = dfl.get_attack_info(attack_id)
        return attack_info

    final_attack_info = query_attack_status(test_info.attacks[0]['id'])
    print(f"Status of attack {final_attack_info}")
    return final_attack_info
# Run the membership inference test against the saved model and wait for it to finish.
launch_attack(args=args)