Commit
Merge pull request #1 from soukaryag/master
Interactive feature addition, bug fixes, code abstraction and separation
qiyanjun authored Oct 30, 2020
2 parents be2f81c + 268a8fa commit 5a27f91
Showing 8 changed files with 255 additions and 87 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
__pycache__/
venv/
/.vscode
results.p
1 change: 1 addition & 0 deletions Procfile
@@ -0,0 +1 @@
web: sh setup.sh && streamlit run app.py
175 changes: 88 additions & 87 deletions app.py
@@ -1,4 +1,5 @@
import streamlit as st
import numpy as np

from argparse import Namespace

@@ -10,21 +11,15 @@
import random
import re

logger = logging.getLogger(__name__)
from models.html_helper import HtmlHelper
from models.args import Args
from models.cache import Cache

INITIAL_INSTRUCTIONS_HTML = """<p style="font-size:1.em; font-weight: 300">👋 Welcome to the TextAttack demo app! Please select a model and an attack recipe from the dropdown.</p> <hr style="margin: 1.em 0;">"""
logger = logging.getLogger(__name__)

from config import NUM_SAMPLES_TO_ATTACK, MODELS, ATTACK_RECIPES, HIDDEN_ATTACK_RECIPES, PRECOMPUTED_RESULTS_DICT_NAME
from config import NUM_SAMPLES_TO_ATTACK, MODELS, ATTACK_RECIPES, HIDDEN_ATTACK_RECIPES, PRECOMPUTED_RESULTS_DICT_NAME, HISTORY

def load_precomputed_results():
try:
precomputed_results = pickle.load(open(PRECOMPUTED_RESULTS_DICT_NAME, "rb" ))
except FileNotFoundError:
precomputed_results = {}
print(f'Found {len(precomputed_results)} keys in pre-computed results.')
return precomputed_results

def load_attack(model_name, attack_recipe_name):
def load_attack(model_name, attack_recipe_name, num_examples):
# Load model.
model_class_name = MODELS[model_name][0]
logger.info(f"Loading transformers.AutoModelForSequenceClassification from '{model_class_name}'.")
@@ -34,54 +29,24 @@ def load_attack(model_name, attack_recipe_name):
except OSError:
logger.warn('Couldn\'t find tokenizer; defaulting to "bert-base-uncased".')
tokenizer = textattack.models.tokenizers.AutoTokenizer("bert-base-uncased")
setattr(model, "tokenizer", tokenizer)
model = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

# Load attack.
logger.info(f"Loading attack from recipe {attack_recipe_name}.")
attack = eval(f"{ATTACK_RECIPES[attack_recipe_name]}(model)")
attack = eval(f"{ATTACK_RECIPES[attack_recipe_name]}.build(model)")

# Load dataset.
_, dataset_args = MODELS[model_name]
dataset = textattack.datasets.HuggingFaceNlpDataset(
dataset = textattack.datasets.HuggingFaceDataset(
*dataset_args, shuffle=True
)
dataset.examples = dataset.examples[:num_examples]
return model, attack, dataset

def improve_result_html(result_html):
result_html = result_html.replace("color = bold", 'style="font-weight: bold;"')
result_html = result_html.replace("color = underline", 'style="text-decoration: underline;"')
result_html = result_html.replace('<font style="font-weight: bold;"', '<span style=""') # no bolding for now
result_html = result_html.replace('<font style="text-decoration: underline;"', '<span style="text-decoration: underline;"')
result_html = re.sub(r"<font\scolor\s=\s(\w.*?)>", r'<span style="background-color: \1; padding: 1.2px; font-weight: 600;">', result_html)
# replace font colors with transparent highlight versions
result_html = result_html.replace(': red', ': rgba(255, 0, 0, .7)') \
.replace(': green', ': rgb(0, 255, 0, .7)') \
.replace(': blue', ': rgb(0, 0, 255, .7)') \
.replace(': gray', ': rgb(220, 220, 220, .7)')
result_html = result_html.replace("</font>", "</span>")
return result_html

def get_attack_result_status(attack_result):
status_html = attack_result.goal_function_result_str(color_method='html')
return improve_result_html(status_html)

def get_attack_result_html(idx, attack_result):
result_status = get_attack_result_status(attack_result)
result_html_lines = attack_result.str_lines(color_method='html')
result_html_lines = [improve_result_html(line) for line in result_html_lines]
rows = [
['', result_status],
['Input', result_html_lines[1]]
]

if len(result_html_lines) > 2:
rows.append(['Output', result_html_lines[2]])

table_html = '\n'.join((f'<b>{header}:</b> {body} <br>' if header else f'{body} <br>') for header,body in rows)
return f'<h3>Result {idx+1}</h3> {table_html} <br>'

@st.cache
def get_attack_recipe_prototype(attack_recipe_name):
""" a sort of hacky way to print an attack recipe without loading a big model"""
recipe = eval(textattack.commands.attack.attack_args.ATTACK_RECIPE_NAMES[attack_recipe_name])
recipe = eval(textattack.commands.attack.attack_args.ATTACK_RECIPE_NAMES[attack_recipe_name]).build
dummy_tokenizer = Namespace(**{ 'encode': None})
dummy_model = Namespace(**{ 'tokenizer': dummy_tokenizer })
recipe = recipe(dummy_model)
@@ -91,70 +56,81 @@ def get_attack_recipe_prototype(attack_recipe_name):
del dummy_tokenizer
return recipe_str

def display_history(fake_latency=False):
history = PRECOMPUTE_CACHE.get(HISTORY)
for idx, result in enumerate(history):
if fake_latency: random_latency()
st.markdown(HtmlHelper.get_attack_result_html(idx, result), unsafe_allow_html=True)

def random_latency():
# Artificially inject a tiny bit of latency to provide
# a feel of the attack _running_.
time.sleep(random.triangular(0., 2., .8))

@st.cache(suppress_st_warning=True,allow_output_mutation=True)
def get_and_print_attack_results(model_name, attack_recipe_name):
def get_and_print_attack_results(model_name, attack_recipe_name, num_examples):
with st.spinner(f'Loading `{model_name}` model and `{attack_recipe_name}` attack...'):
model, attack, dataset = load_attack(model_name, attack_recipe_name)
model, attack, dataset = load_attack(model_name, attack_recipe_name, num_examples)
dataset_name = dataset._name

# Run attack.
from collections import deque
worklist = deque(range(0, NUM_SAMPLES_TO_ATTACK))
worklist = deque(range(0, num_examples))
results = []
with st.spinner(f'Running attack on {NUM_SAMPLES_TO_ATTACK} samples from nlp dataset "{dataset_name}"...'):
with st.spinner(f'Running attack on {num_examples} samples from nlp dataset "{dataset_name}"...'):
for idx, result in enumerate(attack.attack_dataset(dataset, indices=worklist)):
st.markdown(get_attack_result_html(idx, result), unsafe_allow_html=True)
st.markdown(HtmlHelper.get_attack_result_html(idx, result), unsafe_allow_html=True)
results.append(result)

# Update precomputed results
PRECOMPUTED_RESULTS = load_precomputed_results()
PRECOMPUTED_RESULTS[(model_name, attack_recipe_name)] = results
pickle.dump(PRECOMPUTED_RESULTS, open(PRECOMPUTED_RESULTS_DICT_NAME, 'wb'))
# Return results
return { 'results': results, 'already_printed': True }

def random_latency():
# Artificially inject a tiny bit of latency to provide
# a feel of the attack _running_.
time.sleep(random.triangular(0., 2., .8))
# Update precomputed results
PRECOMPUTE_CACHE.add((model_name, attack_recipe_name), results)

def run_attack(model_name, attack_recipe_name):
if (model_name, attack_recipe_name) in PRECOMPUTED_RESULTS:
results = PRECOMPUTED_RESULTS[(model_name, attack_recipe_name)]
for idx, result in enumerate(results):
random_latency()
st.markdown(get_attack_result_html(idx, result), unsafe_allow_html=True)
def run_attack_interactive(text, model_name, attack_recipe_name):
if PRECOMPUTE_CACHE.exists((text, model_name, attack_recipe_name)) and PRECOMPUTE_CACHE.exists(HISTORY):
PRECOMPUTE_CACHE.to_top((text, model_name, attack_recipe_name))
display_history(fake_latency=True)
else:
# Precompute results
results_dict = get_and_print_attack_results(model_name, attack_recipe_name)
results = results_dict['results']
# Print attack results, as long as this wasn't the first time they were computed.
if not results_dict['already_printed']:
for idx, result in enumerate(results):
random_latency()
st.markdown(get_attack_result_html(idx, result), unsafe_allow_html=True)
results_dict['already_printed'] = False
# print summary
attack = textattack.commands.attack.attack_args_helpers.parse_attack_from_args(Args(model_name, attack_recipe_name))
attacked_text = textattack.shared.attacked_text.AttackedText(text)
initial_result = attack.goal_function.get_output(attacked_text)
result = next(attack.attack_dataset([(text, initial_result)]))

# Update precomputed results
PRECOMPUTE_CACHE.add((text, model_name, attack_recipe_name), result)
display_history()

def run_attack(model_name, attack_recipe_name, num_examples):
if PRECOMPUTE_CACHE.exists((model_name, attack_recipe_name)):
PRECOMPUTE_CACHE.to_top((model_name, attack_recipe_name))
display_history(fake_latency=True)
else:
get_and_print_attack_results(model_name, attack_recipe_name, num_examples)


def process_attack_recipe_doc(attack_recipe_text):
attack_recipe_text = attack_recipe_text.strip()
attack_recipe_text = "\n".join(map(lambda line: line.strip(), attack_recipe_text.split("\n")))
return attack_recipe_text

def main():
# Print instructions.
st.markdown(INITIAL_INSTRUCTIONS_HTML, unsafe_allow_html=True)
st.beta_set_page_config(page_title='TextAttack Web Demo', page_icon='https://cdn.shopify.com/s/files/1/1061/1924/products/Octopus_Iphone_Emoji_JPG_large.png', initial_sidebar_state ='auto')
st.markdown(HtmlHelper.INITIAL_INSTRUCTIONS_HTML, unsafe_allow_html=True)

# Print TextAttack info to sidebar.
st.sidebar.markdown('<h1 style="text-align:center; font-size: 1.5em;">TextAttack 🐙</h1>', unsafe_allow_html=True)
st.sidebar.markdown('<p style="font-size:1.em; text-align:center;"><a href="https://github.com/QData/TextAttack">https://github.com/QData/TextAttack</a></p>', unsafe_allow_html=True)
st.sidebar.markdown('<hr>', unsafe_allow_html=True)

# Select model.
all_model_names = list(re.sub(r'-mr$', '-rotten_tomatoes', m) for m in MODELS.keys())
model_names = list(sorted(set(map(lambda x: x.replace(x[x.rfind('-'):],''), all_model_names))))
model_default = 'bert-base-uncased'
model_default_index = model_names.index(model_default)
interactive = st.sidebar.checkbox('Interactive')
model_name = st.sidebar.selectbox('Model from transformers:', model_names, index=model_default_index)

# Select dataset. (TODO make this less hacky.)
if interactive:
interactive_text = st.sidebar.text_input('Custom Input Data')
matching_model_keys = list(m for m in all_model_names if m.startswith(model_name))
dataset_names = list(sorted(map(lambda x: x.replace(x[:x.rfind('-')+1],''), matching_model_keys)))
dataset_default_index = 0
@@ -166,28 +142,53 @@ def main():
continue
dataset_name = st.sidebar.selectbox('Dataset from nlp:', dataset_names, index=dataset_default_index)
full_model_name = '-'.join((model_name, dataset_name)).replace('-rotten_tomatoes', '-mr')

# Select attack recipe.
recipe_names = list(sorted(ATTACK_RECIPES.keys()))
for hidden_attack in HIDDEN_ATTACK_RECIPES: recipe_names.remove(hidden_attack)
recipe_default = 'textfooler'
recipe_default_index = recipe_names.index(recipe_default)
attack_recipe_name = st.sidebar.selectbox('Attack recipe', recipe_names, index=recipe_default_index)

# Select number of examples to be displayed
if not interactive:
num_examples = st.sidebar.slider('Number of Examples', 1, 100, value=10, step=1)

# Run attack on button press.
if st.sidebar.button('Run attack'):
# Run full attack.
run_attack(full_model_name, attack_recipe_name)
if interactive: run_attack_interactive(interactive_text, full_model_name, attack_recipe_name)
else: run_attack(full_model_name, attack_recipe_name, num_examples)
else:
# Print History of Usage
timeline_history = PRECOMPUTE_CACHE.get(HISTORY)
for idx, entry in enumerate(timeline_history):
st.markdown(HtmlHelper.get_attack_result_html(idx, entry), unsafe_allow_html=True)

# Display clear history button
if PRECOMPUTE_CACHE.exists(HISTORY):
clear_history = st.button("Clear History")
if clear_history:
PRECOMPUTE_CACHE.purge(key=HISTORY)

# TODO print attack metrics somewhere?
# Add model info to sidebar.
hf_model_name = MODELS[full_model_name][0]
model_link = f"https://huggingface.co/{hf_model_name}"
st.markdown(f"### Model Hub Link \n [[{hf_model_name}]({model_link})]", unsafe_allow_html=True)

# Add attack info to sidebar (TODO would main page be better?).
attack_recipe_doc = process_attack_recipe_doc(eval(f"{ATTACK_RECIPES[attack_recipe_name]}.__doc__"))
st.sidebar.markdown(f'<hr style="margin: 1.0em 0;"> <h3>Attack Recipe:</h3>\n<b>Name:</b> {attack_recipe_name} <br> <br> {attack_recipe_doc}', unsafe_allow_html=True)

# Print attack recipe composition
attack_recipe_prototype = get_attack_recipe_prototype(attack_recipe_name)
st.markdown(f'### Attack Recipe Prototype \n```\n{attack_recipe_prototype}\n```')

purge_cache = st.button("Purge Local Cache")
if purge_cache:
PRECOMPUTE_CACHE.purge()

if __name__ == "__main__": # @TODO split model & dataset into 2 dropdowns
PRECOMPUTED_RESULTS = load_precomputed_results()
if __name__ == "__main__":
PRECOMPUTE_CACHE = Cache(log=False)
main()
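
app.py now imports its markup helpers (HtmlHelper.INITIAL_INSTRUCTIONS_HTML, HtmlHelper.get_attack_result_html) from models/html_helper.py, one of the changed files whose diff did not render on this page. The sketch below reconstructs what that module presumably contains from the helper code removed from app.py above; the class layout and the staticmethod/classmethod split are assumptions, and the CSS strings are lightly tidied.

```python
# Hypothetical models/html_helper.py, reconstructed from the helpers removed
# from app.py above; structure and decorators are assumptions.
import re


class HtmlHelper:
    INITIAL_INSTRUCTIONS_HTML = (
        '<p style="font-size:1.em; font-weight: 300">👋 Welcome to the TextAttack demo app! '
        'Please select a model and an attack recipe from the dropdown.</p> '
        '<hr style="margin: 1.em 0;">'
    )

    @staticmethod
    def improve_result_html(result_html):
        # Rewrite textattack's <font>-based coloring into inline-styled <span>s.
        result_html = result_html.replace("color = bold", 'style="font-weight: bold;"')
        result_html = result_html.replace("color = underline", 'style="text-decoration: underline;"')
        result_html = result_html.replace('<font style="font-weight: bold;"', '<span style=""')  # no bolding for now
        result_html = result_html.replace(
            '<font style="text-decoration: underline;"', '<span style="text-decoration: underline;"'
        )
        result_html = re.sub(
            r"<font\scolor\s=\s(\w.*?)>",
            r'<span style="background-color: \1; padding: 1.2px; font-weight: 600;">',
            result_html,
        )
        # Replace plain font colors with transparent highlight versions.
        result_html = (
            result_html.replace(": red", ": rgba(255, 0, 0, .7)")
            .replace(": green", ": rgba(0, 255, 0, .7)")
            .replace(": blue", ": rgba(0, 0, 255, .7)")
            .replace(": gray", ": rgba(220, 220, 220, .7)")
        )
        return result_html.replace("</font>", "</span>")

    @classmethod
    def get_attack_result_html(cls, idx, attack_result):
        # Render one AttackResult as a small HTML card: status line, input, output.
        status = cls.improve_result_html(attack_result.goal_function_result_str(color_method="html"))
        lines = [cls.improve_result_html(line) for line in attack_result.str_lines(color_method="html")]
        rows = [("", status), ("Input", lines[1])]
        if len(lines) > 2:
            rows.append(("Output", lines[2]))
        table_html = "\n".join(
            (f"<b>{header}:</b> {body} <br>" if header else f"{body} <br>") for header, body in rows
        )
        return f"<h3>Result {idx + 1}</h3> {table_html} <br>"
```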
1 change: 1 addition & 0 deletions config.py
@@ -6,3 +6,4 @@
HIDDEN_ATTACK_RECIPES = ['alzantot', 'seq2sick', 'hotflip']

PRECOMPUTED_RESULTS_DICT_NAME = 'results.p'
HISTORY = 'timeline_history'
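
For context, config.py also defines the names app.py imports (NUM_SAMPLES_TO_ATTACK, MODELS, ATTACK_RECIPES); only its last few lines appear in this hunk. The shapes below follow how app.py consumes those objects, but every concrete entry is an illustrative assumption rather than the repository's actual values.

```python
# Illustrative shape of config.py; entries are assumptions inferred from how
# app.py indexes these objects, not the repository's real contents.
NUM_SAMPLES_TO_ATTACK = 10

# model key -> (HuggingFace model name, args forwarded to textattack.datasets.HuggingFaceDataset)
MODELS = {
    "bert-base-uncased-mr": ("textattack/bert-base-uncased-rotten-tomatoes", ("rotten_tomatoes",)),
    # ... more "<architecture>-<dataset>" entries
}

# recipe key -> expression that app.py eval()s as f"{value}.build(model)"
ATTACK_RECIPES = {
    "textfooler": "textattack.attack_recipes.TextFoolerJin2019",
    "alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
}

HIDDEN_ATTACK_RECIPES = ["alzantot", "seq2sick", "hotflip"]

PRECOMPUTED_RESULTS_DICT_NAME = "results.p"
HISTORY = "timeline_history"
```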
11 changes: 11 additions & 0 deletions models/args.py
@@ -0,0 +1,11 @@
class Args():
def __init__(self, model, recipe, model_batch_size=32, query_budget=200, model_cache_size=2**18, constraint_cache_size=2**18):
self.model = model
self.recipe = recipe
self.model_batch_size = model_batch_size
self.model_cache_size = model_cache_size
self.query_budget = query_budget
self.constraint_cache_size = constraint_cache_size

def __getattr__(self, item):
return False
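
The new Args class stands in for the argparse namespace that textattack.commands.attack.attack_args_helpers.parse_attack_from_args expects (see run_attack_interactive in app.py above). Because __getattr__ returns False, any option that parse_attack_from_args probes but that Args never set simply reads as disabled. A small usage sketch, with illustrative argument values:

```python
# Usage sketch for models/args.py; the model and recipe names are examples
# taken from app.py's defaults, not a prescribed invocation.
from models.args import Args

args = Args("bert-base-uncased-mr", "textfooler")

print(args.model)         # "bert-base-uncased-mr"
print(args.query_budget)  # 200 (set in __init__)
print(args.interactive)   # False -- not set in __init__, so __getattr__ answers
                          # for it and the option reads as "disabled"
```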
60 changes: 60 additions & 0 deletions models/cache.py
@@ -0,0 +1,60 @@
import pickle

from config import PRECOMPUTED_RESULTS_DICT_NAME, HISTORY

class Cache():
def __init__(self, log=False):
self.log = log
self.cache = self.load_precomputed_results()

def load_precomputed_results(self):
try:
precomputed_results = pickle.load(open(PRECOMPUTED_RESULTS_DICT_NAME, "rb"))
except FileNotFoundError:
precomputed_results = {}
if self.log: print(f'Found {len(precomputed_results)} keys in pre-computed results.')
return precomputed_results

def add(self, key, data):
self.cache = self.load_precomputed_results()
self.cache[key] = data

# update history
if isinstance(data, list):
self.cache[HISTORY] = data + self.cache.get(HISTORY, [])
else:
self.cache[HISTORY] = [data] + self.cache.get(HISTORY, [])

pickle.dump(self.cache, open(PRECOMPUTED_RESULTS_DICT_NAME, 'wb'))
if self.log: print(f'Successfully added {key} to the cache')

def to_top(self, key):
self.cache = self.load_precomputed_results()
data, history = self.cache.get(key, None), self.cache.get(HISTORY, None)
if not data or not history:
return []

if isinstance(data, list):
for d in data:
history.pop(history.index(d))
history.insert(0, d)
else:
history.pop(history.index(data))
history.insert(0, data)

def exists(self, key):
self.cache = self.load_precomputed_results()
return key in self.cache

def purge(self, key=None):
self.cache = self.load_precomputed_results()
if not key:
self.cache.clear()
elif key in self.cache:
del self.cache[key]
if self.log: print(f'Cache successfully purged')
pickle.dump(self.cache, open(PRECOMPUTED_RESULTS_DICT_NAME, 'wb'))

def get(self, key):
self.cache = self.load_precomputed_results()
return self.cache.get(key, [])
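
The Cache wraps the results.p pickle that load_precomputed_results previously read directly in app.py, and additionally keeps a HISTORY list of results in display order. Below is a usage sketch mirroring how app.py drives it; the stand-in string takes the place of the textattack AttackResult objects the app actually stores.

```python
# Usage sketch for models/cache.py, following the call pattern in app.py.
from models.cache import Cache
from config import HISTORY

cache = Cache(log=True)

key = ("bert-base-uncased-mr", "textfooler")
if cache.exists(key):
    cache.to_top(key)  # bump this key's results to the front of the HISTORY list
else:
    # add() stores the results under `key`, prepends them to HISTORY,
    # and persists the whole dict back to results.p.
    cache.add(key, ["stand-in attack result"])

for idx, result in enumerate(cache.get(HISTORY)):
    print(idx, result)  # app.py renders these via HtmlHelper.get_attack_result_html

cache.purge(key=HISTORY)  # drop only the displayed history; cache.purge() wipes every key
```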
(Diffs for the remaining 2 changed files did not load in this view.)
