PII Obfuscation¶
This notebook demonstrates how PII (Personally Identifiable Information) obfuscation can work within the context of an LLM Agent.
Overview¶
The goal is to prevent the LLM from ever seeing any PII. We achieve this with the following flow (a minimal sketch follows the list):
- Mask the raw user message upon arrival.
- Pass the masked message (along with chat history) to the LLM.
- The LLM invokes a tool using PII tokens instead of actual PII.
- Inside the tool, access the vault to unmask the PII.
- Perform the underlying lookup with the real user data and get a response.
- Mask the response and return it to the LLM, allowing it to respond using tokens.
- Unmask the final response before displaying it to the user.
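To make the flow concrete, here is a minimal, self-contained sketch of the mask → LLM → unmask roundtrip, using a plain dict as a stand-in for the vault (the real implementation below uses llm-guard's Vault and an actual LLM):
# Minimal illustration of the roundtrip, with a plain dict standing in for the vault
vault_demo = {"[REDACTED_EMAIL_ADDRESS_1]": "johnsmith@gmail.com"}
# The LLM only ever sees the masked prompt and answers in terms of tokens
masked_prompt = "Hi my email is [REDACTED_EMAIL_ADDRESS_1]"
masked_reply = "Thanks! I have your email on file as [REDACTED_EMAIL_ADDRESS_1]."
# Unmask the reply before showing it to the user
for token, original in vault_demo.items():
    masked_reply = masked_reply.replace(token, original)
print(masked_reply)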
Setup¶
First, install the required libraries.
# Install required libraries
%pip install -qU llm-guard langgraph langchain-core langchain-openai python-dotenv
Load environment variables (e.g., API keys) from a .env
file.
# Load environment variables from a .env file
from dotenv import load_dotenv
load_dotenv()
# Import Vault and Anonymize classes from llm-guard
from llm_guard.vault import Vault
from llm_guard.input_scanners import Anonymize
# Initialize a vault to store PII and create a scanner for PII anonymization
vault = Vault()
scanner = Anonymize(vault)
Example: Masking PII¶
# Example: Mask PII in a sample string
response = scanner.scan("Hi my email is johnsmith@gmail.com")
print(f"Sanitized Prompt: {response[0]}")
# The PII is stored in the vault
scanner._vault.get()
Adding More PII to the Vault¶
# Subsequent scans will add new PII to the vault
response = scanner.scan("My name is John Smith")
scanner._vault.get()
Unmasking Function¶
We define a function to unmask text using the PII stored in the vault.
# Define a function to unmask text using the PII stored in the vault
def unmask(scanner: Anonymize, text: str) -> str:
# Retrieve the list of PII entities from the vault
entities = scanner._vault.get()
# Loop through the entities and replace the tokens with the original PII strings
for token, original_pii in entities:
text = text.replace(token, original_pii)
return text
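As a quick sanity check, we can round-trip a masked string. The token name below is illustrative; the actual tokens depend on what the scanner has stored in the vault.
# Quick check: tokens captured by the earlier scans are swapped back to the originals.
# The token name below is illustrative; actual tokens depend on the scanner's output.
masked_text = "Please send the report to [REDACTED_EMAIL_ADDRESS_1]."
print(unmask(scanner, masked_text))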
Creating the Account Lookup Tool¶
We create a mock account lookup function and wrap it with a @tool
decorator for use in LangGraph.
The tool:
- Unmasks the input arguments.
- Performs the account lookup (mocked).
- Masks the output.
- Returns the masked output.
import json
from langchain_core.tools import tool
@tool
def account_lookup(masked_name: str, masked_email: str):
"""
Look up a user's account information based on their name and email.
Expects inputs to be masked PII tokens.
Parameters:
- masked_name (str): Masked token representing the user's name.
- masked_email (str): Masked token representing the user's email.
Returns:
- dict: Masked account information.
"""
# Unmask the arguments to get the real PII for account lookup
real_name = unmask(scanner, masked_name)
real_email = unmask(scanner, masked_email)
# Mock account lookup process
print(f"Looking up account for {real_name} with email {real_email}")
mock_account_data = {
"name": masked_name,
"email": masked_email,
"username": "jsmith22",
"phone_number": "(555) 555-1234",
"address": "1234 Main St, Anytown, USA",
"account_balance": "$1,000.75"
}
print(f"Found account: {mock_account_data['username']}")
# Mask any PII in the account data before returning
# Mask the dict by scanning its JSON string representation
masked_account_str = scanner.scan(json.dumps(mock_account_data))[0]
masked_account_data = json.loads(masked_account_str)
# Manually mask fields that may not be automatically masked
# For example, mask the username and address
scanner._vault.append(("[REDACTED_USERNAME_1]", mock_account_data["username"]))
scanner._vault.append(("[REDACTED_ADDRESS_1]", mock_account_data["address"]))
masked_account_data["username"] = "[REDACTED_USERNAME_1]"
masked_account_data["address"] = "[REDACTED_ADDRESS_1]"
# Return the masked account data to the LLM
return masked_account_data
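Before wiring the tool into the graph, it can be invoked directly to see the mask/unmask behavior. The token arguments below are illustrative placeholders; in the agent they come from the LLM's tool call.
# Direct invocation for testing; the token arguments are illustrative placeholders
result = account_lookup.invoke({
    "masked_name": "[REDACTED_PERSON_1]",
    "masked_email": "[REDACTED_EMAIL_ADDRESS_1]",
})
print(result)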
Building the Agent with LangGraph¶
We use LangGraph to build the agent. The state only needs to track the messages
since we access the vault via the scanner
object.
from typing_extensions import TypedDict
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
# Initialize the LLM with the account_lookup tool bound to it
llm_with_tools = ChatOpenAI(
model="gpt-4o-mini", temperature=0).bind_tools([account_lookup])
# Define the system prompt for the assistant
system_prompt = """You are a customer assistant agent. Your job is to look up the user's account information if they request it.
The user's personal details are masked in the transcript and replaced with tokens (e.g., '[REDACTED_NAME_1]'). Use the tokens in the
account lookup tool arguments when invoking it."""
# Define the agent's state structure
class AgentState(TypedDict):
# We want to allow overwriting of messages so that we can mask and unmask them in pre- and post-processing
messages: list[BaseMessage]
Defining the Nodes¶
We define the preprocessing, postprocessing, and model call nodes.
import copy
# Preprocessing node: Mask PII in the messages before sending them to the LLM
def pre_process(state: AgentState) -> AgentState:
"""
Mask PII in the messages before sending them to the LLM.
"""
# Deep copy the messages to avoid modifying the original state
messages = copy.deepcopy(state["messages"])
for message in messages:
# Replace the message content with the masked version
message.content = scanner.scan(message.content)[0]
return {"messages": messages}
# Postprocessing node: Unmask PII in the messages before returning them to the user
def post_process(state: AgentState) -> AgentState:
"""
Unmask PII in the messages before returning them to the user.
"""
# Deep copy the messages to avoid modifying the original state
messages = copy.deepcopy(state["messages"])
for message in messages:
# Replace the message content with the unmasked version
message.content = unmask(scanner, message.content)
return {"messages": messages}
# Model call node: Invoke the LLM with the masked messages
def call_model(state: AgentState) -> AgentState:
"""
Call the LLM with the masked messages.
"""
messages = state["messages"]
# Add the system prompt to the beginning of the messages array
system_message = SystemMessage(content=system_prompt)
# Invoke the LLM
response = llm_with_tools.invoke([system_message] + messages)
# Return the messages including the LLM's response
return {"messages": messages + [response]}
Custom Tool Node¶
Because our state overwrites the messages returned by each node rather than appending to them, the tool node must return the full messages array. The prebuilt ToolNode
only returns the tool messages, so we define a custom one.
# Define the custom tool node
class ToolNode:
"""
A node that runs the tools requested in the last AIMessage.
"""
def __init__(self, tools: list) -> None:
self.tools_by_name = {tool.name: tool for tool in tools}
def __call__(self, inputs: dict):
# Retrieve the messages from the inputs
if messages := inputs.get("messages", []):
# Get the last message (from the AI)
message = messages[-1]
else:
raise ValueError("No message found in input")
outputs = []
# Process each tool call in the message
for tool_call in message.tool_calls:
# Invoke the tool with the provided arguments
tool_result = self.tools_by_name[tool_call["name"]].invoke(
tool_call["args"]
)
# Create a ToolMessage with the result
outputs.append(
ToolMessage(
content=json.dumps(tool_result),
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)
# Return the messages including the tool responses
return {"messages": messages + outputs}
Conditional Routing¶
We need to route to the appropriate node based on whether the LLM's output is a tool invocation or a content response.
def route_llm_output(state: AgentState) -> str:
"""
Determine the next node based on the LLM's output.
Returns:
- "tool" if the LLM invoked a tool.
- "end" if the LLM produced a content response.
"""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tool"
else:
return "end"
Building the Graph¶
We assemble the nodes and define the edges to build the agent.
from langgraph.graph import StateGraph
graph_builder = StateGraph(AgentState)
# Add the custom nodes
graph_builder.add_node("Preprocess", pre_process)
graph_builder.add_node("Call Model", call_model)
graph_builder.add_node("Post Process", post_process)
# Add the custom tool node
graph_builder.add_node("Tool Call", ToolNode(tools=[account_lookup]))
# Define the edges
# After preprocessing, call the model
graph_builder.add_edge("Preprocess", "Call Model")
# After the model call, route based on the LLM's output
graph_builder.add_conditional_edges(
"Call Model",
route_llm_output,
{"tool": "Tool Call", "end": "Post Process"}
)
# After tool calls, return to the model call node
graph_builder.add_edge("Tool Call", "Call Model")
# Set entry and finish points
graph_builder.set_entry_point("Preprocess")
graph_builder.set_finish_point("Post Process")
# Initialize a new vault and scanner for a fresh session
vault = Vault()
scanner = Anonymize(vault)
# Compile the graph to create the agent
app = graph_builder.compile()
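Optionally, render the compiled graph to confirm the wiring (this assumes IPython and LangGraph's graph-drawing extras are available in your environment):
# Optional: visualize the compiled graph
from IPython.display import Image, display
display(Image(app.get_graph().draw_mermaid_png()))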
First User Message¶
# Simulate a user message containing PII
output: AgentState = app.invoke(
{"messages": [HumanMessage(content="Hi my email is johnsmith@gmail.com")]})
# Print the output state
print(output)
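We can also inspect the vault after the run to confirm which PII tokens were captured during this turn.
# Each vault entry is a (token, original_value) pair collected during the run
scanner._vault.get()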
Second User Message¶
# Simulate a follow-up user message containing PII
output: AgentState = app.invoke(
{"messages": output["messages"] + [HumanMessage(content="Yea my name is John Smith. What is my address and account balance?")]})
# Print the output state
print(output)
Displaying the Conversation¶
# Print the conversation messages in a readable format
for message in output["messages"]:
message.pretty_print()
In the resulting trace, we can see that the PII was masked internally at each step before finally being unmasked for the output.