Store the image moderation and text moderation logs #3478
base: main
File 1 of 2 (anonymous side-by-side battle arena module):

@@ -33,24 +33,27 @@
     acknowledgment_md,
     get_ip,
     get_model_description_md,
+    _write_to_json,
 )
+from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator
 from fastchat.serve.remote_logger import get_remote_logger
 from fastchat.utils import (
     build_logger,
-    moderation_filter,
 )

 logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

 num_sides = 2
 enable_moderation = False
+use_remote_storage = False
 anony_names = ["", ""]
 models = []


-def set_global_vars_anony(enable_moderation_):
-    global enable_moderation
+def set_global_vars_anony(enable_moderation_, use_remote_storage_):
+    global enable_moderation, use_remote_storage
     enable_moderation = enable_moderation_
+    use_remote_storage = use_remote_storage_


 def load_demo_side_by_side_anony(models_, url_params):
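For context, here is a minimal sketch of how the new two-argument signature might be driven from the server entry point. The flag name and the argparse wiring are assumptions, not part of this diff; only the set_global_vars_anony call shape is taken from the changed code.

import argparse

# Module path assumed from the function names in this diff.
from fastchat.serve.gradio_block_arena_anony import set_global_vars_anony

parser = argparse.ArgumentParser()
parser.add_argument("--moderate", action="store_true")
# Assumed flag: defaults to False so the server runs without any
# remote storage bucket configured.
parser.add_argument("--use-remote-storage", action="store_true")
args = parser.parse_args()

# Both module globals are set in one call, matching the new signature.
set_global_vars_anony(args.moderate, args.use_remote_storage)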
@@ -215,6 +218,9 @@ def get_battle_pair(
     if len(models) == 1:
         return models[0], models[0]

+    if len(models) == 0:
+        raise ValueError("There are no models provided. Cannot get battle pair.")
+
     model_weights = []
     for model in models:
         weight = get_sample_weight(
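The effect of the new guard, in isolation, is to fail fast with a clear message instead of erroring later inside the sampling logic. A simplified sketch (the real get_battle_pair takes additional sampling arguments):

def get_battle_pair_sketch(models):
    if len(models) == 1:
        return models[0], models[0]

    if len(models) == 0:
        raise ValueError("There are no models provided. Cannot get battle pair.")

    ...  # weighted sampling over the model list continues here

try:
    get_battle_pair_sketch([])
except ValueError as err:
    print(err)  # There are no models provided. Cannot get battle pair.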
@@ -311,7 +317,11 @@ def add_text(
     all_conv_text = (
         all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
     )
-    flagged = moderation_filter(all_conv_text, model_list, do_moderation=True)
+
+    content_moderator = AzureAndOpenAIContentModerator()
+    flagged = content_moderator.text_moderation_filter(
+        all_conv_text, model_list, do_moderation=True
+    )
     if flagged:
         logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
         # overwrite the original text
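The moderator's interface can be read off the call sites in this diff. The sketch below is only that inferred shape with placeholder bodies, not the real fastchat.serve.moderation.moderator implementation:

class AzureAndOpenAIContentModerator:
    """Shape inferred from usage in this PR; bodies are stubs."""

    def __init__(self):
        self.text_flagged = False   # set by text moderation
        self.nsfw_flagged = False   # set by image moderation
        self._last_response = None

    def text_moderation_filter(self, text, model_list, do_moderation=False):
        # Would call the OpenAI and/or Azure moderation endpoints and
        # record the verdict; stubbed here to always pass.
        if do_moderation:
            self.text_flagged = False  # placeholder verdict
        return self.text_flagged

    def update_last_moderation_response(self, response):
        # Called with None after a flagged turn is logged, resetting
        # the stored moderation response.
        self._last_response = response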
@@ -364,18 +374,50 @@ def bot_response_multi(
     request: gr.Request,
 ):
     logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
+    states = [state0, state1]

+    if states[0] is None or states[0].skip_next:
+        if (
+            states[0].content_moderator.text_flagged
+            or states[0].content_moderator.nsfw_flagged
+        ):
+            for i in range(num_sides):
+                # This generate call is skipped due to invalid inputs
+                start_tstamp = time.time()
+                finish_tstamp = start_tstamp
+                states[i].conv.save_new_images(
+                    has_csam_images=states[i].has_csam_image,
+                    use_remote_storage=use_remote_storage,
+                )
+
+                filename = get_conv_log_filename(
+                    is_vision=states[i].is_vision,
+                    has_csam_image=states[i].has_csam_image,
+                )
+
+                _write_to_json(
+                    filename,
+                    start_tstamp,
+                    finish_tstamp,
+                    states[i],
+                    temperature,
+                    top_p,
+                    max_new_tokens,
+                    request,
+                )
+
+                # Remove the last message: the user input
+                states[i].conv.messages.pop()
+                states[i].content_moderator.update_last_moderation_response(None)
+
-    if state0 is None or state0.skip_next:
         # This generate call is skipped due to invalid inputs
         yield (
-            state0,
-            state1,
-            state0.to_gradio_chatbot(),
-            state1.to_gradio_chatbot(),
+            states[0],
+            states[1],
+            states[0].to_gradio_chatbot(),
+            states[1].to_gradio_chatbot(),
         ) + (no_change_btn,) * 6
         return

-    states = [state0, state1]
     gen = []
     for i in range(num_sides):
         gen.append(

Reviewer comment on the line `if states[0] is None or states[0].skip_next:`: Maybe we can use this variable.
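For reference, a plausible shape for the _write_to_json helper, inferred only from the call above and from FastChat's usual JSONL conversation logs; the field names here are assumptions, not confirmed by this diff:

import json

def _write_to_json(filename, start_tstamp, finish_tstamp, state,
                   temperature, top_p, max_new_tokens, request):
    # Append one JSON record per (skipped or completed) generation.
    data = {
        "tstamp": round(finish_tstamp, 4),
        "type": "chat",
        "model": state.model_name,
        "gen_params": {
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
        },
        "start": round(start_tstamp, 4),
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),      # assumed serialization method
        "ip": get_ip(request),      # helper imported in this diff
    }
    with open(filename, "a") as fout:
        fout.write(json.dumps(data) + "\n")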
File 2 of 2 (named side-by-side battle arena module):

@@ -28,22 +28,27 @@
     acknowledgment_md,
     get_ip,
     get_model_description_md,
+    _write_to_json,
+    show_vote_button,
+    dont_show_vote_button,
 )
+from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator
 from fastchat.serve.remote_logger import get_remote_logger
 from fastchat.utils import (
     build_logger,
-    moderation_filter,
 )

 logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

 num_sides = 2
 enable_moderation = False
+use_remote_storage = False


-def set_global_vars_named(enable_moderation_):
-    global enable_moderation
+def set_global_vars_named(enable_moderation_, use_remote_storage_):
+    global enable_moderation, use_remote_storage
     enable_moderation = enable_moderation_
+    use_remote_storage = use_remote_storage_


 def load_demo_side_by_side_named(models, url_params):
@@ -175,19 +180,27 @@ def add_text(
             no_change_btn,
         ]
         * 6
+        + [dont_show_vote_button]
     )

     model_list = [states[i].model_name for i in range(num_sides)]
-    all_conv_text_left = states[0].conv.get_prompt()
-    all_conv_text_right = states[1].conv.get_prompt()
-    all_conv_text = (
-        all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
-    )
-    flagged = moderation_filter(all_conv_text, model_list)
-    if flagged:
-        logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
-        # overwrite the original text
-        text = MODERATION_MSG
+    text_flagged = states[0].content_moderator.text_moderation_filter(text, model_list)
+
+    if text_flagged:
+        logger.info(f"violate moderation. ip: {ip}. text: {text}")
+        for i in range(num_sides):
+            states[i].skip_next = True
+        gr.Warning(MODERATION_MSG)
+        return (
+            states
+            + [x.to_gradio_chatbot() for x in states]
+            + [""]
+            + [
+                no_change_btn,
+            ]
+            * 6
+            + [dont_show_vote_button]
+        )

     conv = states[0].conv
     if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
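show_vote_button and dont_show_vote_button are imported at the top of this file but not defined in this excerpt. A guess at their nature, purely an assumption: simple sentinels that add_text returns as its last output, which the event wiring shown below routes into the show_vote_buttons gr.State.

# Assumed definitions, not shown in this diff:
show_vote_button = True        # returned after a successful add_text
dont_show_vote_button = False  # returned when the turn is flagged or limited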
@@ -202,6 +215,7 @@ def add_text(
             no_change_btn,
         ]
         * 6
+        + [dont_show_vote_button]
     )

     text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
@@ -218,6 +232,7 @@ def add_text(
             disable_btn,
         ]
         * 6
+        + [show_vote_button]
     )
@@ -231,17 +246,49 @@ def bot_response_multi(
 ):
     logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")

-    if state0.skip_next:
-        # This generate call is skipped due to invalid inputs
+    states = [state0, state1]
+    if states[0].skip_next:
+        if (
+            states[0].content_moderator.text_flagged
+            or states[0].content_moderator.nsfw_flagged
+        ):
+            for i in range(num_sides):
+                # This generate call is skipped due to invalid inputs
+                start_tstamp = time.time()
+                finish_tstamp = start_tstamp
+                states[i].conv.save_new_images(
+                    has_csam_images=states[i].has_csam_image,
+                    use_remote_storage=use_remote_storage,
+                )
+
+                filename = get_conv_log_filename(
+                    is_vision=states[i].is_vision,
+                    has_csam_image=states[i].has_csam_image,
+                )
+
+                _write_to_json(
+                    filename,
+                    start_tstamp,
+                    finish_tstamp,
+                    states[i],
+                    temperature,
+                    top_p,
+                    max_new_tokens,
+                    request,
+                )
+
+                # Remove the last message: the user input
+                states[i].conv.messages.pop()
+                states[i].content_moderator.update_last_moderation_response(None)
+
         yield (
-            state0,
-            state1,
-            state0.to_gradio_chatbot(),
-            state1.to_gradio_chatbot(),
+            states[0],
+            states[1],
+            states[0].to_gradio_chatbot(),
+            states[1].to_gradio_chatbot(),
         ) + (no_change_btn,) * 6
         return

-    states = [state0, state1]
     gen = []
     for i in range(num_sides):
         gen.append(
@@ -301,14 +348,19 @@ def bot_response_multi(
             break


-def flash_buttons():
+def flash_buttons(show_vote_buttons: bool = True):
     btn_updates = [
         [disable_btn] * 4 + [enable_btn] * 2,
         [enable_btn] * 6,
     ]
-    for i in range(4):
-        yield btn_updates[i % 2]
-        time.sleep(0.3)
+
+    if show_vote_buttons:
+        for i in range(4):
+            yield btn_updates[i % 2]
+            time.sleep(0.3)
+    else:
+        yield [no_change_btn] * 4 + [enable_btn] * 2
+        return


 def build_side_by_side_ui_named(models):

Reviewer comment on flash_buttons: Sorry, could you say more about what this is for?
Author reply: We shouldn't flash the vote buttons if the text fails the moderation test. Essentially, people shouldn't be able to vote if it fails, since there will be no output.
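Concretely, the two modes of the revised generator (a sketch; btn_list presumably holds the four vote buttons followed by the regenerate and clear buttons):

# show_vote_buttons=True: four alternating yields create the flash and
# end with all six buttons enabled.
# show_vote_buttons=False: a single yield that leaves the four vote
# buttons untouched, so a moderation-flagged turn can never be voted on.
updates = list(flash_buttons(show_vote_buttons=False))
assert updates == [[no_change_btn] * 4 + [enable_btn] * 2]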
@@ -328,6 +380,7 @@ def build_side_by_side_ui_named(models):
     states = [gr.State() for _ in range(num_sides)]
     model_selectors = [None] * num_sides
     chatbots = [None] * num_sides
+    show_vote_buttons = gr.State(True)

     notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
@@ -489,24 +542,24 @@ def build_side_by_side_ui_named(models):
     textbox.submit(
         add_text,
         states + model_selectors + [textbox],
-        states + chatbots + [textbox] + btn_list,
+        states + chatbots + [textbox] + btn_list + [show_vote_buttons],
     ).then(
         bot_response_multi,
         states + [temperature, top_p, max_output_tokens],
         states + chatbots + btn_list,
     ).then(
-        flash_buttons, [], btn_list
+        flash_buttons, [show_vote_buttons], btn_list
     )
     send_btn.click(
         add_text,
         states + model_selectors + [textbox],
-        states + chatbots + [textbox] + btn_list,
+        states + chatbots + [textbox] + btn_list + [show_vote_buttons],
     ).then(
         bot_response_multi,
         states + [temperature, top_p, max_output_tokens],
         states + chatbots + btn_list,
     ).then(
-        flash_buttons, [], btn_list
+        flash_buttons, [show_vote_buttons], btn_list
     )

     return states + model_selectors
Reviewer comment on use_remote_storage: Why does this need to be a global variable? Also, should it be globally False?
Author reply: I think globally False is the correct decision, because we have a set bucket where we place images and not everyone will do it that way. Having it default to False means anyone can run this without Google Cloud Storage.
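A sketch of the opt-in behavior the discussion settles on; the local-fallback logic and the helper names here are hypothetical, not taken from this PR:

def save_new_images(self, has_csam_images=False, use_remote_storage=False):
    # Hypothetical sketch: with the global defaulting to False, images
    # go to local disk, so no Google Cloud Storage bucket is needed to
    # run the server; the remote path is reached only on explicit opt-in.
    for image in self._new_images:          # hypothetical attribute
        if use_remote_storage:
            upload_to_remote_bucket(image)  # hypothetical helper
        else:
            save_image_locally(image)       # hypothetical helper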