diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 4a46103ec..03627a7b6 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -614,7 +614,11 @@ def save_new_images(self, has_csam_images=False, use_remote_storage=False): from fastchat.utils import load_image, upload_image_file_to_gcs from PIL import Image - _, last_user_message = self.messages[-2] + last_user_message = None + for role, message in reversed(self.messages): + if role == "user": + last_user_message = message + break if type(last_user_message) == tuple: text, images = last_user_message[0], last_user_message[1] diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py index 625c69c44..e98f5d74e 100644 --- a/fastchat/serve/gradio_block_arena_anony.py +++ b/fastchat/serve/gradio_block_arena_anony.py @@ -33,24 +33,27 @@ acknowledgment_md, get_ip, get_model_description_md, + _write_to_json, ) +from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, - moderation_filter, ) logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") num_sides = 2 enable_moderation = False +use_remote_storage = False anony_names = ["", ""] models = [] -def set_global_vars_anony(enable_moderation_): - global enable_moderation +def set_global_vars_anony(enable_moderation_, use_remote_storage_): + global enable_moderation, use_remote_storage enable_moderation = enable_moderation_ + use_remote_storage = use_remote_storage_ def load_demo_side_by_side_anony(models_, url_params): @@ -215,6 +218,9 @@ def get_battle_pair( if len(models) == 1: return models[0], models[0] + if len(models) == 0: + raise ValueError("There are no models provided. Cannot get battle pair.") + model_weights = [] for model in models: weight = get_sample_weight( @@ -311,7 +317,11 @@ def add_text( all_conv_text = ( all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text ) - flagged = moderation_filter(all_conv_text, model_list, do_moderation=True) + + content_moderator = AzureAndOpenAIContentModerator() + flagged = content_moderator.text_moderation_filter( + all_conv_text, model_list, do_moderation=True + ) if flagged: logger.info(f"violate moderation (anony). ip: {ip}. text: {text}") # overwrite the original text @@ -364,18 +374,50 @@ def bot_response_multi( request: gr.Request, ): logger.info(f"bot_response_multi (anony). 
ip: {get_ip(request)}") + states = [state0, state1] + + if states[0] is None or states[0].skip_next: + if ( + states[0].content_moderator.text_flagged + or states[0].content_moderator.nsfw_flagged + ): + for i in range(num_sides): + # This generate call is skipped due to invalid inputs + start_tstamp = time.time() + finish_tstamp = start_tstamp + states[i].conv.save_new_images( + has_csam_images=states[i].has_csam_image, + use_remote_storage=use_remote_storage, + ) + + filename = get_conv_log_filename( + is_vision=states[i].is_vision, + has_csam_image=states[i].has_csam_image, + ) + + _write_to_json( + filename, + start_tstamp, + finish_tstamp, + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + + # Remove the last message: the user input + states[i].conv.messages.pop() + states[i].content_moderator.update_last_moderation_response(None) - if state0 is None or state0.skip_next: - # This generate call is skipped due to invalid inputs yield ( - state0, - state1, - state0.to_gradio_chatbot(), - state1.to_gradio_chatbot(), + states[0], + states[1], + states[0].to_gradio_chatbot(), + states[1].to_gradio_chatbot(), ) + (no_change_btn,) * 6 return - states = [state0, state1] gen = [] for i in range(num_sides): gen.append( diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py index 2f7b39adb..4292b0bc0 100644 --- a/fastchat/serve/gradio_block_arena_named.py +++ b/fastchat/serve/gradio_block_arena_named.py @@ -28,22 +28,27 @@ acknowledgment_md, get_ip, get_model_description_md, + _write_to_json, + show_vote_button, + dont_show_vote_button, ) +from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, - moderation_filter, ) logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") num_sides = 2 enable_moderation = False +use_remote_storage = False -def set_global_vars_named(enable_moderation_): - global enable_moderation +def set_global_vars_named(enable_moderation_, use_remote_storage_): + global enable_moderation, use_remote_storage enable_moderation = enable_moderation_ + use_remote_storage = use_remote_storage_ def load_demo_side_by_side_named(models, url_params): @@ -175,19 +180,27 @@ def add_text( no_change_btn, ] * 6 + + [dont_show_vote_button] ) model_list = [states[i].model_name for i in range(num_sides)] - all_conv_text_left = states[0].conv.get_prompt() - all_conv_text_right = states[1].conv.get_prompt() - all_conv_text = ( - all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text - ) - flagged = moderation_filter(all_conv_text, model_list) - if flagged: - logger.info(f"violate moderation (named). ip: {ip}. text: {text}") - # overwrite the original text - text = MODERATION_MSG + text_flagged = states[0].content_moderator.text_moderation_filter(text, model_list) + + if text_flagged: + logger.info(f"violate moderation. ip: {ip}. 
text: {text}") + for i in range(num_sides): + states[i].skip_next = True + gr.Warning(MODERATION_MSG) + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + no_change_btn, + ] + * 6 + + [dont_show_vote_button] + ) conv = states[0].conv if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: @@ -202,6 +215,7 @@ def add_text( no_change_btn, ] * 6 + + [dont_show_vote_button] ) text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off @@ -218,6 +232,7 @@ def add_text( disable_btn, ] * 6 + + [show_vote_button] ) @@ -231,17 +246,49 @@ def bot_response_multi( ): logger.info(f"bot_response_multi (named). ip: {get_ip(request)}") - if state0.skip_next: - # This generate call is skipped due to invalid inputs + states = [state0, state1] + if states[0].skip_next: + if ( + states[0].content_moderator.text_flagged + or states[0].content_moderator.nsfw_flagged + ): + for i in range(num_sides): + # This generate call is skipped due to invalid inputs + start_tstamp = time.time() + finish_tstamp = start_tstamp + states[i].conv.save_new_images( + has_csam_images=states[i].has_csam_image, + use_remote_storage=use_remote_storage, + ) + + filename = get_conv_log_filename( + is_vision=states[i].is_vision, + has_csam_image=states[i].has_csam_image, + ) + + _write_to_json( + filename, + start_tstamp, + finish_tstamp, + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + + # Remove the last message: the user input + states[i].conv.messages.pop() + states[i].content_moderator.update_last_moderation_response(None) + yield ( - state0, - state1, - state0.to_gradio_chatbot(), - state1.to_gradio_chatbot(), + states[0], + states[1], + states[0].to_gradio_chatbot(), + states[1].to_gradio_chatbot(), ) + (no_change_btn,) * 6 return - states = [state0, state1] gen = [] for i in range(num_sides): gen.append( @@ -301,14 +348,19 @@ def bot_response_multi( break -def flash_buttons(): +def flash_buttons(show_vote_buttons: bool = True): btn_updates = [ [disable_btn] * 4 + [enable_btn] * 2, [enable_btn] * 6, ] - for i in range(4): - yield btn_updates[i % 2] - time.sleep(0.3) + + if show_vote_buttons: + for i in range(4): + yield btn_updates[i % 2] + time.sleep(0.3) + else: + yield [no_change_btn] * 4 + [enable_btn] * 2 + return def build_side_by_side_ui_named(models): @@ -328,6 +380,7 @@ def build_side_by_side_ui_named(models): states = [gr.State() for _ in range(num_sides)] model_selectors = [None] * num_sides chatbots = [None] * num_sides + show_vote_buttons = gr.State(True) notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") @@ -489,24 +542,24 @@ def build_side_by_side_ui_named(models): textbox.submit( add_text, states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list, + states + chatbots + [textbox] + btn_list + [show_vote_buttons], ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( - flash_buttons, [], btn_list + flash_buttons, [show_vote_buttons], btn_list ) send_btn.click( add_text, states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list, + states + chatbots + [textbox] + btn_list + [show_vote_buttons], ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( - flash_buttons, [], btn_list + flash_buttons, [show_vote_buttons], btn_list ) return states + model_selectors diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index 
b3d812220..e3e0a5f6a 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -39,11 +39,10 @@ get_conv_log_filename, get_remote_logger, ) +from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator from fastchat.serve.vision.image import ImageFormat, Image from fastchat.utils import ( build_logger, - moderation_filter, - image_moderation_filter, ) logger = build_logger("gradio_web_server", "gradio_web_server.log") @@ -54,8 +53,16 @@ invisible_btn = gr.Button(interactive=False, visible=False) visible_image_column = gr.Image(visible=True) invisible_image_column = gr.Image(visible=False) -enable_multimodal = gr.MultimodalTextbox( - interactive=True, visible=True, placeholder="Enter your prompt or add image here" +enable_multimodal_keep_input = gr.MultimodalTextbox( + interactive=True, + visible=True, + placeholder="Enter your prompt or add image here", +) +enable_multimodal_clear_input = gr.MultimodalTextbox( + interactive=True, + visible=True, + placeholder="Enter your prompt or add image here", + value={"text": "", "files": []}, ) invisible_text = gr.Textbox(visible=False, value="", interactive=False) visible_text = gr.Textbox( @@ -140,6 +147,7 @@ def regenerate(state, request: gr.Request): state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 state.conv.update_last_message(None) + state.content_moderator.update_last_moderation_response(None) return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5 @@ -147,7 +155,7 @@ def clear_history(request: gr.Request): ip = get_ip(request) logger.info(f"clear_history. ip: {ip}") state = None - return (state, [], enable_multimodal, invisible_text, invisible_btn) + ( + return (state, [], enable_multimodal_clear_input, invisible_text, invisible_btn) + ( disable_btn, ) * 5 @@ -156,7 +164,7 @@ def clear_history_example(request: gr.Request): ip = get_ip(request) logger.info(f"clear_history_example. ip: {ip}") state = None - return (state, [], enable_multimodal, invisible_text, invisible_btn) + ( + return (state, [], enable_multimodal_keep_input, invisible_text, invisible_btn) + ( disable_btn, ) * 5 @@ -166,7 +174,7 @@ def report_csam_image(state, image): pass -def _prepare_text_with_image(state, text, images, csam_flag): +def _prepare_text_with_image(state, text, images): if len(images) > 0: if len(state.conv.get_images()) > 0: # reset convo with new image @@ -191,38 +199,7 @@ def convert_images_to_conversation_format(images): return conv_images -def moderate_input(state, text, all_conv_text, model_list, images, ip): - text_flagged = moderation_filter(all_conv_text, model_list) - # flagged = moderation_filter(text, [state.model_name]) - nsfw_flagged, csam_flagged = False, False - if len(images) > 0: - nsfw_flagged, csam_flagged = image_moderation_filter(images[0]) - - image_flagged = nsfw_flagged or csam_flagged - if text_flagged or image_flagged: - logger.info(f"violate moderation. ip: {ip}. 
text: {all_conv_text}") - if text_flagged and not image_flagged: - # overwrite the original text - text = TEXT_MODERATION_MSG - elif not text_flagged and image_flagged: - text = IMAGE_MODERATION_MSG - elif text_flagged and image_flagged: - text = MODERATION_MSG - - if csam_flagged: - state.has_csam_image = True - report_csam_image(state, images[0]) - - return text, image_flagged, csam_flagged - - -def add_text( - state, - model_selector, - chat_input: Union[str, dict], - context: Context, - request: gr.Request, -): +def add_text(state, model_selector, chat_input, context: Context, request: gr.Request): if isinstance(chat_input, dict): text, images = chat_input["text"], chat_input["files"] else: @@ -235,7 +212,6 @@ def add_text( ): gr.Warning(f"{model_selector} is a text-only model. Image is ignored.") images = [] - ip = get_ip(request) logger.info(f"add_text. ip: {ip}. len: {len(text)}") @@ -256,20 +232,37 @@ def add_text( images = convert_images_to_conversation_format(images) - text, image_flagged, csam_flag = moderate_input( - state, text, all_conv_text, [state.model_name], images, ip + # Use the first state to get the moderation response because this is based on user input so it is independent of the model + moderation_image_input = images[0] if len(images) > 0 else None + moderation_type_to_response_map = ( + state.content_moderator.image_and_text_moderation_filter( + moderation_image_input, text, [state.model_name], do_moderation=False + ) + ) + + text_flagged, nsfw_flag, csam_flag = ( + state.content_moderator.text_flagged, + state.content_moderator.nsfw_flagged, + state.content_moderator.csam_flagged, ) - if image_flagged: - logger.info(f"image flagged. ip: {ip}. text: {text}") + if csam_flag: + state.has_csam_image = True + + state.content_moderator.append_moderation_response(moderation_type_to_response_map) + + if text_flagged or nsfw_flag: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + gradio_chatbot_before_user_input = state.to_gradio_chatbot() + post_processed_text = _prepare_text_with_image(state, text, images) + state.conv.append_message(state.conv.roles[0], post_processed_text) state.skip_next = True + gr.Warning(MODERATION_MSG) return ( - state, - state.to_gradio_chatbot(), - {"text": IMAGE_MODERATION_MSG}, - "", - no_change_btn, - ) + (no_change_btn,) * 5 + (state, gradio_chatbot_before_user_input, None, "", no_change_btn) + + (no_change_btn,) + + (no_change_btn,) * 5 + ) if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {ip}. 
text: {text}") @@ -283,7 +276,7 @@ def add_text( ) + (no_change_btn,) * 5 text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off - text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag) + text = _prepare_text_with_image(state, text, images) state.conv.append_message(state.conv.roles[0], text) state.conv.append_message(state.conv.roles[1], None) return ( @@ -318,10 +311,7 @@ def build_single_vision_language_model_ui( state = gr.State() gr.Markdown(notice_markdown, elem_id="notice_markdown") - vision_not_in_text_models = [ - model for model in context.vision_models if model not in context.text_models - ] - text_and_vision_models = context.text_models + vision_not_in_text_models + text_and_vision_models = list(set(context.text_models + context.vision_models)) context_state = gr.State(context) with gr.Group(): diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py index d4d4d484e..fce547580 100644 --- a/fastchat/serve/gradio_block_arena_vision_anony.py +++ b/fastchat/serve/gradio_block_arena_vision_anony.py @@ -35,6 +35,9 @@ get_model_description_md, disable_text, enable_text, + use_remote_storage, + show_vote_button, + dont_show_vote_button, ) from fastchat.serve.gradio_block_arena_anony import ( flash_buttons, @@ -60,20 +63,22 @@ set_invisible_image, set_visible_image, add_image, - moderate_input, - enable_multimodal, + enable_multimodal_keep_input, _prepare_text_with_image, convert_images_to_conversation_format, invisible_text, visible_text, disable_multimodal, + enable_multimodal_clear_input, +) +from fastchat.serve.moderation.moderator import ( + BaseContentModerator, + AzureAndOpenAIContentModerator, ) from fastchat.serve.gradio_global_state import Context from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, - moderation_filter, - image_moderation_filter, ) logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") @@ -85,7 +90,20 @@ vl_models = [] # TODO(chris): fix sampling weights -VISION_SAMPLING_WEIGHTS = {} +VISION_SAMPLING_WEIGHTS = { + "gpt-4o-2024-05-13": 4, + "gpt-4-turbo-2024-04-09": 4, + "claude-3-haiku-20240307": 4, + "claude-3-sonnet-20240229": 4, + "claude-3-5-sonnet-20240620": 4, + "claude-3-opus-20240229": 4, + "gemini-1.5-flash-api-0514": 4, + "gemini-1.5-pro-api-0514": 4, + "llava-v1.6-34b": 4, + "reka-core-20240501": 4, + "reka-flash-preview-20240611": 4, + "reka-flash": 4, +} # TODO(chris): Find battle targets that make sense VISION_BATTLE_TARGETS = {} @@ -120,7 +138,7 @@ def clear_history_example(request: gr.Request): [None] * num_sides + [None] * num_sides + anony_names - + [enable_multimodal, invisible_text, invisible_btn] + + [enable_multimodal_keep_input, invisible_text, invisible_btn] + [invisible_btn] * 4 + [disable_btn] * 2 + [enable_btn] @@ -216,6 +234,7 @@ def regenerate(state0, state1, request: gr.Request): if state0.regen_support and state1.regen_support: for i in range(num_sides): states[i].conv.update_last_message(None) + states[i].content_moderator.update_last_moderation_response(None) return ( states + [x.to_gradio_chatbot() for x in states] @@ -235,7 +254,7 @@ def clear_history(request: gr.Request): [None] * num_sides + [None] * num_sides + anony_names - + [enable_multimodal, invisible_text, invisible_btn] + + [enable_multimodal_clear_input, invisible_text, invisible_btn] + [invisible_btn] * 4 + [disable_btn] * 2 + [enable_btn] @@ -305,15 +324,33 @@ def add_text( ] * 7 + [""] + + [dont_show_vote_button] ) 
model_list = [states[i].model_name for i in range(num_sides)] images = convert_images_to_conversation_format(images) - text, image_flagged, csam_flag = moderate_input( - state0, text, text, model_list, images, ip + # Use the first state to get the moderation response because this is based on user input so it is independent of the model + moderation_image_input = images[0] if len(images) > 0 else None + moderation_type_to_response_map = states[ + 0 + ].content_moderator.image_and_text_moderation_filter( + moderation_image_input, text, model_list, do_moderation=True ) + text_flagged, nsfw_flag, csam_flag = ( + states[0].content_moderator.text_flagged, + states[0].content_moderator.nsfw_flagged, + states[0].content_moderator.csam_flagged, + ) + + if csam_flag: + states[0].has_csam_image, states[1].has_csam_image = True, True + + for state in states: + state.content_moderator.append_moderation_response( + moderation_type_to_response_map + ) conv = states[0].conv if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: @@ -329,32 +366,34 @@ def add_text( ] * 7 + [""] + + [dont_show_vote_button] ) - if image_flagged: - logger.info(f"image flagged. ip: {ip}. text: {text}") + if text_flagged or nsfw_flag: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + # We call this before appending the text so it does not appear in the UI + gradio_chatbot_list = [x.to_gradio_chatbot() for x in states] for i in range(num_sides): + post_processed_text = _prepare_text_with_image(states[i], text, images) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].skip_next = True + gr.Warning(MODERATION_MSG) return ( states - + [x.to_gradio_chatbot() for x in states] + + gradio_chatbot_list + [ - { - "text": IMAGE_MODERATION_MSG - + " PLEASE CLICK ๐ŸŽฒ NEW ROUND TO START A NEW CONVERSATION." 
- }, + None, "", no_change_btn, ] - + [no_change_btn] * 7 + + [disable_btn] * 7 + [""] + + [dont_show_vote_button] ) text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off for i in range(num_sides): - post_processed_text = _prepare_text_with_image( - states[i], text, images, csam_flag=csam_flag - ) + post_processed_text = _prepare_text_with_image(states[i], text, images) states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].conv.append_message(states[i].conv.roles[1], None) states[i].skip_next = False @@ -372,6 +411,7 @@ def add_text( ] * 7 + [hint_msg] + + [show_vote_button] ) @@ -398,9 +438,11 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): states = [gr.State() for _ in range(num_sides)] model_selectors = [None] * num_sides chatbots = [None] * num_sides + show_vote_buttons = gr.State(True) + context_state = gr.State(context) gr.Markdown(notice_markdown, elem_id="notice_markdown") - text_and_vision_models = context.models + text_and_vision_models = list(set(context.text_models + context.vision_models)) with gr.Row(): with gr.Column(scale=2, visible=False) as image_column: @@ -478,6 +520,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): elem_id="input_box", scale=3, ) + send_btn = gr.Button( value="Send", variant="primary", scale=1, visible=False, interactive=False ) @@ -612,14 +655,15 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): + [multimodal_textbox, textbox, send_btn] + btn_list + [random_btn] - + [slow_warning], + + [slow_warning] + + [show_vote_buttons], ).then(set_invisible_image, [], [image_column]).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( flash_buttons, - [], + [show_vote_buttons], btn_list, ) @@ -631,14 +675,15 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): + [multimodal_textbox, textbox, send_btn] + btn_list + [random_btn] - + [slow_warning], + + [slow_warning] + + [show_vote_buttons], ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( flash_buttons, - [], + [show_vote_buttons], btn_list, ) @@ -650,14 +695,15 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): + [multimodal_textbox, textbox, send_btn] + btn_list + [random_btn] - + [slow_warning], + + [slow_warning] + + [show_vote_buttons], ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( flash_buttons, - [], + [show_vote_buttons], btn_list, ) diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py index 7c653acf3..65b283d4a 100644 --- a/fastchat/serve/gradio_block_arena_vision_named.py +++ b/fastchat/serve/gradio_block_arena_vision_named.py @@ -32,15 +32,19 @@ set_invisible_image, set_visible_image, add_image, - moderate_input, _prepare_text_with_image, convert_images_to_conversation_format, - enable_multimodal, + enable_multimodal_keep_input, + enable_multimodal_clear_input, disable_multimodal, invisible_text, invisible_btn, visible_text, ) +from fastchat.serve.moderation.moderator import ( + BaseContentModerator, + AzureAndOpenAIContentModerator, +) from fastchat.serve.gradio_global_state import Context from fastchat.serve.gradio_web_server import ( State, @@ -54,12 +58,12 @@ get_ip, get_model_description_md, enable_text, + show_vote_button, + 
dont_show_vote_button, ) from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, - moderation_filter, - image_moderation_filter, ) @@ -83,7 +87,7 @@ def load_demo_side_by_side_vision_named(context: Context): else: model_right = model_left - all_models = context.models + all_models = list(set(context.text_models + context.vision_models)) selector_updates = [ gr.Dropdown(choices=all_models, value=model_left, visible=True), gr.Dropdown(choices=all_models, value=model_right, visible=True), @@ -97,7 +101,7 @@ def clear_history_example(request: gr.Request): return ( [None] * num_sides + [None] * num_sides - + [enable_multimodal, invisible_text, invisible_btn] + + [enable_multimodal_keep_input, invisible_text, invisible_btn] + [invisible_btn] * 4 + [disable_btn] * 2 ) @@ -163,6 +167,7 @@ def regenerate(state0, state1, request: gr.Request): if state0.regen_support and state1.regen_support: for i in range(num_sides): states[i].conv.update_last_message(None) + states[i].content_moderator.update_last_moderation_response(None) return ( states + [x.to_gradio_chatbot() for x in states] @@ -181,7 +186,7 @@ def clear_history(request: gr.Request): return ( [None] * num_sides + [None] * num_sides - + [enable_multimodal, invisible_text, invisible_btn] + + [enable_multimodal_clear_input, invisible_text, invisible_btn] + [invisible_btn] * 4 + [disable_btn] * 2 ) @@ -238,6 +243,7 @@ def add_text( no_change_btn, ] * 6 + + [dont_show_vote_button] ) model_list = [states[i].model_name for i in range(num_sides)] @@ -249,10 +255,30 @@ def add_text( images = convert_images_to_conversation_format(images) - text, image_flagged, csam_flag = moderate_input( - state0, text, all_conv_text, model_list, images, ip + # Use the first state to get the moderation response because this is based on user input so it is independent of the model + moderation_image_input = images[0] if len(images) > 0 else None + moderation_type_to_response_map = states[ + 0 + ].content_moderator.image_and_text_moderation_filter( + moderation_image_input, text, model_list, do_moderation=False + ) + text_flagged, nsfw_flag, csam_flag = ( + states[0].content_moderator.text_flagged, + states[0].content_moderator.nsfw_flagged, + states[0].content_moderator.csam_flagged, ) + if csam_flag: + states[0].has_csam_image, states[1].has_csam_image = True, True + + for state in states: + state.content_moderator.append_moderation_response( + moderation_type_to_response_map + ) + + if text_flagged or nsfw_flag: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + conv = states[0].conv if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {ip}. text: {text}") @@ -266,26 +292,34 @@ def add_text( no_change_btn, ] * 6 + + [dont_show_vote_button] ) - if image_flagged: - logger.info(f"image flagged. ip: {ip}. text: {text}") + if text_flagged or nsfw_flag: + logger.info(f"violate moderation. ip: {ip}. 
text: {text}") + gradio_chatbot_list = [x.to_gradio_chatbot() for x in states] for i in range(num_sides): + post_processed_text = _prepare_text_with_image(states[i], text, images) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].skip_next = True + gr.Warning(MODERATION_MSG) return ( states - + [x.to_gradio_chatbot() for x in states] - + [{"text": IMAGE_MODERATION_MSG}, "", no_change_btn] + + gradio_chatbot_list + + [None, "", no_change_btn] + [ no_change_btn, ] * 6 + + [dont_show_vote_button] ) text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off for i in range(num_sides): post_processed_text = _prepare_text_with_image( - states[i], text, images, csam_flag=csam_flag + states[i], + text, + images, ) states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].conv.append_message(states[i].conv.roles[1], None) @@ -299,6 +333,7 @@ def add_text( disable_btn, ] * 6 + + [show_vote_button] ) @@ -323,10 +358,11 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None): states = [gr.State() for _ in range(num_sides)] model_selectors = [None] * num_sides chatbots = [None] * num_sides + show_vote_buttons = gr.State(True) notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") - text_and_vision_models = context.models + text_and_vision_models = list(set(context.text_models + context.vision_models)) context_state = gr.State(context) with gr.Row(): @@ -534,37 +570,49 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None): multimodal_textbox.submit( add_text, states + model_selectors + [multimodal_textbox, context_state], - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, + states + + chatbots + + [multimodal_textbox, textbox, send_btn] + + btn_list + + [show_vote_buttons], ).then(set_invisible_image, [], [image_column]).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( - flash_buttons, [], btn_list + flash_buttons, [show_vote_buttons], btn_list ) textbox.submit( add_text, states + model_selectors + [textbox, context_state], - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, + states + + chatbots + + [multimodal_textbox, textbox, send_btn] + + btn_list + + [show_vote_buttons], ).then(set_invisible_image, [], [image_column]).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( - flash_buttons, [], btn_list + flash_buttons, [show_vote_buttons], btn_list ) send_btn.click( add_text, states + model_selectors + [textbox, context_state], - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, + states + + chatbots + + [multimodal_textbox, textbox, send_btn] + + btn_list + + [show_vote_buttons], ).then(set_invisible_image, [], [image_column]).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( - flash_buttons, [], btn_list + flash_buttons, [show_vote_buttons], btn_list ) if random_questions: diff --git a/fastchat/serve/gradio_global_state.py b/fastchat/serve/gradio_global_state.py index fafaec213..e05022f18 100644 --- a/fastchat/serve/gradio_global_state.py +++ b/fastchat/serve/gradio_global_state.py @@ -8,5 +8,3 @@ class Context: all_text_models: List[str] = field(default_factory=list) vision_models: List[str] = field(default_factory=list) all_vision_models: List[str] = field(default_factory=list) - models: List[str] = 
field(default_factory=list) - all_models: List[str] = field(default_factory=list) diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 4f0521da0..7a0863363 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -34,13 +34,13 @@ ) from fastchat.model.model_registry import get_model_info, model_info from fastchat.serve.api_provider import get_api_provider_stream_iter +from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator from fastchat.serve.gradio_global_state import Context from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, get_window_url_params_js, get_window_url_params_with_tos_js, - moderation_filter, parse_gradio_auth_creds, load_image, ) @@ -61,6 +61,8 @@ visible=True, placeholder='Press "๐ŸŽฒ New Round" to start over๐Ÿ‘‡ (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)', ) +show_vote_button = True +dont_show_vote_button = False controller_url = None enable_moderation = False @@ -119,6 +121,7 @@ def __init__(self, model_name, is_vision=False): self.model_name = model_name self.oai_thread_id = None self.is_vision = is_vision + self.content_moderator = AzureAndOpenAIContentModerator() # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes. self.has_csam_image = False @@ -151,6 +154,7 @@ def dict(self): { "conv_id": self.conv_id, "model_name": self.model_name, + "moderation": self.content_moderator.conv_moderation_responses, } ) @@ -240,8 +244,7 @@ def load_demo_single(context: Context, query_params): if model in models: selected_model = model - all_models = context.models - + all_models = list(set(context.text_models + context.vision_models)) dropdown_update = gr.Dropdown( choices=all_models, value=selected_model, visible=True ) @@ -308,6 +311,7 @@ def regenerate(state, request: gr.Request): state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 state.conv.update_last_message(None) + state.content_moderator.update_last_moderation_response(None) return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 @@ -341,14 +345,24 @@ def add_text(state, model_selector, text, request: gr.Request): state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 - all_conv_text = state.conv.get_prompt() - all_conv_text = all_conv_text[-2000:] + "\nuser: " + text - flagged = moderation_filter(all_conv_text, [state.model_name]) - # flagged = moderation_filter(text, [state.model_name]) - if flagged: + content_moderator = AzureAndOpenAIContentModerator() + text_flagged = content_moderator.text_moderation_filter(text, [state.model_name]) + + if text_flagged: logger.info(f"violate moderation. ip: {ip}. text: {text}") # overwrite the original text - text = MODERATION_MSG + content_moderator.write_to_json(get_ip(request)) + state.skip_next = True + gr.Warning(MODERATION_MSG) + return ( + [state] + + [state.to_gradio_chatbot()] + + [""] + + [ + no_change_btn, + ] + * 5 + ) if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {ip}. 
text: {text}") @@ -419,6 +433,36 @@ def is_limit_reached(model_name, ip): return None +def _write_to_json( + filename: str, + start_tstamp: float, + finish_tstamp: float, + state: State, + temperature: float, + top_p: float, + max_new_tokens: int, + request: gr.Request, +): + with open(filename, "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": state.model_name, + "gen_params": { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + return data + + def bot_response( state, temperature, @@ -438,6 +482,32 @@ def bot_response( if state.skip_next: # This generate call is skipped due to invalid inputs state.skip_next = False + if state.content_moderator.text_flagged or state.content_moderator.nsfw_flagged: + start_tstamp = time.time() + finish_tstamp = start_tstamp + state.conv.save_new_images( + has_csam_images=state.has_csam_image, + use_remote_storage=use_remote_storage, + ) + + filename = get_conv_log_filename( + is_vision=state.is_vision, has_csam_image=state.has_csam_image + ) + + _write_to_json( + filename, + start_tstamp, + finish_tstamp, + state, + temperature, + top_p, + max_new_tokens, + request, + ) + + # Remove the last message: the user input + state.conv.messages.pop() + state.content_moderator.update_last_moderation_response(None) yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return @@ -589,22 +659,23 @@ def bot_response( is_vision=state.is_vision, has_csam_image=state.has_csam_image ) - with open(filename, "a") as fout: - data = { - "tstamp": round(finish_tstamp, 4), - "type": "chat", - "model": model_name, - "gen_params": { - "temperature": temperature, - "top_p": top_p, - "max_new_tokens": max_new_tokens, - }, - "start": round(start_tstamp, 4), - "finish": round(finish_tstamp, 4), - "state": state.dict(), - "ip": get_ip(request), - } - fout.write(json.dumps(data) + "\n") + moderation_type_to_response_map = ( + state.content_moderator.image_and_text_moderation_filter( + None, output, [state.model_name], do_moderation=True + ) + ) + state.content_moderator.append_moderation_response(moderation_type_to_response_map) + + data = _write_to_json( + filename, + start_tstamp, + finish_tstamp, + state, + temperature, + top_p, + max_new_tokens, + request, + ) get_remote_logger().log(data) diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index 7a255d59e..f6342f2ec 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -310,8 +310,8 @@ def build_demo( # Set global variables set_global_vars(args.controller_url, args.moderate, args.use_remote_storage) - set_global_vars_named(args.moderate) - set_global_vars_anony(args.moderate) + set_global_vars_named(args.moderate, args.use_remote_storage) + set_global_vars_anony(args.moderate, args.use_remote_storage) text_models, all_text_models = get_model_list( args.controller_url, args.register_api_endpoint_file, @@ -324,20 +324,7 @@ def build_demo( vision_arena=True, ) - models = text_models + [ - model for model in vision_models if model not in text_models - ] - all_models = all_text_models + [ - model for model in all_vision_models if model not in all_text_models - ] - context = Context( - text_models, - all_text_models, - vision_models, - all_vision_models, - models, - all_models, - ) + 
    context = Context(text_models, all_text_models, vision_models, all_vision_models)
 
     # Set authorization credentials
     auth = None
diff --git a/fastchat/serve/moderation/moderator.py b/fastchat/serve/moderation/moderator.py
new file mode 100644
index 000000000..41751d29a
--- /dev/null
+++ b/fastchat/serve/moderation/moderator.py
@@ -0,0 +1,273 @@
+import datetime
+import hashlib
+import os
+import json
+import time
+import base64
+import requests
+from typing import Tuple, Dict, List, Union
+
+from fastchat.constants import LOGDIR
+from fastchat.serve.vision.image import Image
+from fastchat.utils import load_image, upload_image_file_to_gcs
+
+
+class BaseContentModerator:
+    def __init__(self):
+        self.conv_moderation_responses: List[
+            Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+        ] = []
+        self.text_flagged = False
+        self.csam_flagged = False
+        self.nsfw_flagged = False
+
+    def _image_moderation_filter(self, image: Image) -> Tuple[bool, bool]:
+        """Function that detects whether image violates moderation policies.
+
+        Returns:
+            Tuple[bool, bool]: A tuple of two boolean values indicating whether the image was flagged for nsfw and csam respectively.
+        """
+        raise NotImplementedError
+
+    def _text_moderation_filter(self, text: str) -> bool:
+        """Function that detects whether text violates moderation policies.
+
+        Returns:
+            bool: A boolean value indicating whether the text was flagged.
+        """
+        raise NotImplementedError
+
+    def reset_moderation_flags(self):
+        self.text_flagged = False
+        self.csam_flagged = False
+        self.nsfw_flagged = False
+
+    def image_and_text_moderation_filter(
+        self, image: Image, text: str
+    ) -> Dict[str, Dict[str, Union[str, Dict[str, float]]]]:
+        """Function that detects whether image and text violate moderation policies.
+
+        Returns:
+            Dict[str, Dict[str, Union[str, Dict[str, float]]]]: A dictionary that maps the type of moderation (text, nsfw, csam) to a dictionary that contains the moderation response.
+        """
+        raise NotImplementedError
+
+    def append_moderation_response(
+        self, moderation_response: Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+    ):
+        """Function that appends the moderation response to the list of moderation responses."""
+        if (
+            len(self.conv_moderation_responses) == 0
+            or self.conv_moderation_responses[-1] is not None
+        ):
+            self.conv_moderation_responses.append(moderation_response)
+        else:
+            self.update_last_moderation_response(moderation_response)
+
+    def update_last_moderation_response(
+        self, moderation_response: Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+    ):
+        """Function that updates the last moderation response."""
+        self.conv_moderation_responses[-1] = moderation_response
+
+
+class AzureAndOpenAIContentModerator(BaseContentModerator):
+    _NON_TOXIC_IMAGE_MODERATION_MAP = {
+        "nsfw_moderation": {"flagged": False},
+        "csam_moderation": {"flagged": False},
+    }
+
+    def __init__(self, use_remote_storage: bool = False):
+        """This class is used to moderate content using Azure and OpenAI.
+
+        conv_to_moderation_responses: A dictionary that is a map from the type of moderation
+        (text, nsfw, csam) moderation to the moderation response returned from the request sent
+        to the moderation API.
+        """
+        super().__init__()
+
+    def _image_moderation_request(
+        self, image_bytes: bytes, endpoint: str, api_key: str
+    ) -> dict:
+        headers = {"Content-Type": "image/jpeg", "Ocp-Apim-Subscription-Key": api_key}
+
+        MAX_RETRIES = 3
+        for _ in range(MAX_RETRIES):
+            response = requests.post(endpoint, headers=headers, data=image_bytes).json()
+            try:
+                if response["Status"]["Code"] == 3000:
+                    break
+            except:
+                time.sleep(0.5)
+        return response
+
+    def _image_moderation_provider(self, image_bytes: bytes, api_type: str) -> bool:
+        if api_type == "nsfw":
+            endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"]
+            api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"]
+            response = self._image_moderation_request(image_bytes, endpoint, api_key)
+            flagged = response["IsImageAdultClassified"]
+        elif api_type == "csam":
+            endpoint = (
+                "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false"
+            )
+            api_key = os.environ["PHOTODNA_API_KEY"]
+            response = self._image_moderation_request(image_bytes, endpoint, api_key)
+            flagged = response["IsMatch"]
+
+        image_md5_hash = hashlib.md5(image_bytes).hexdigest()
+        moderation_response_map = {
+            "image_hash": image_md5_hash,
+            "response": response,
+            "flagged": False,
+        }
+        if flagged:
+            moderation_response_map["flagged"] = True
+
+        return moderation_response_map
+
+    def image_moderation_filter(self, image: Image):
+        print(f"moderating image")
+
+        image_bytes = base64.b64decode(image.base64_str)
+
+        nsfw_flagged_map = self._image_moderation_provider(image_bytes, "nsfw")
+
+        if nsfw_flagged_map["flagged"]:
+            csam_flagged_map = self._image_moderation_provider(image_bytes, "csam")
+        else:
+            csam_flagged_map = {"flagged": False}
+
+        self.nsfw_flagged = nsfw_flagged_map["flagged"]
+        self.csam_flagged = csam_flagged_map["flagged"]
+
+        # We save only the boolean value instead of the entire response dictionary
+        # to save space. nsfw_flagged_map and csam_flagged_map will contain the whole dictionary
+        return {
+            "nsfw_moderation": {"flagged": self.nsfw_flagged},
+            "csam_moderation": {"flagged": self.csam_flagged},
+        }
+
+    def _openai_moderation_filter(
+        self, text: str, custom_thresholds: dict = None
+    ) -> bool:
+        """
+        Check whether the text violates OpenAI moderation API.
+        """
+        import openai
+
+        client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+        # default to true to be conservative
+        flagged = True
+        MAX_RETRY = 3
+        moderation_response_map = {"content": text, "response": None, "flagged": False}
+        for _ in range(MAX_RETRY):
+            try:
+                res = client.moderations.create(input=text)
+                flagged = res.results[0].flagged
+                if custom_thresholds is not None:
+                    for category, threshold in custom_thresholds.items():
+                        if (
+                            getattr(res.results[0].category_scores, category)
+                            > threshold
+                        ):
+                            flagged = True
+                moderation_response_map = {
+                    "response": dict(res.results[0].category_scores),
+                    "flagged": flagged,
+                }
+                break
+            except (openai.OpenAIError, KeyError, IndexError) as e:
+                print(f"MODERATION ERROR: {e}\nInput: {text}")
+
+        return moderation_response_map
+
+    def text_moderation_filter(
+        self, text: str, model_list: List[str], do_moderation: bool = False
+    ):
+        # Apply moderation for below models
+        MODEL_KEYWORDS = [
+            "claude",
+            "gpt",
+            "bard",
+            "mistral-large",
+            "command-r",
+            "dbrx",
+            "gemini",
+            "reka",
+        ]
+
+        custom_thresholds = {"sexual": 0.3}
+        # set a stricter threshold for claude
+        for model in model_list:
+            if "claude" in model:
+                custom_thresholds = {"sexual": 0.2}
+
+        for keyword in MODEL_KEYWORDS:
+            for model in model_list:
+                if keyword in model:
+                    do_moderation = True
+                    break
+
+        moderation_response_map = {"flagged": False}
+        if do_moderation:
+            # We save the entire response dictionary here
+            moderation_response_map = self._openai_moderation_filter(
+                text, custom_thresholds
+            )
+            self.text_flagged = moderation_response_map["flagged"]
+        else:
+            self.text_flagged = False
+
+        # We only save whether the text was flagged or not instead of the entire response dictionary
+        # to save space. moderation_response_map will contain the whole dictionary
+        return {"text_moderation": {"flagged": self.text_flagged}}
+
+    def image_and_text_moderation_filter(
+        self, image: Image, text: str, model_list: List[str], do_moderation=True
+    ) -> Dict[str, bool]:
+        """Function that detects whether image and text violate moderation policies using the Azure and OpenAI moderation APIs.
+
+        Returns:
+            Dict[str, Dict[str, Union[str, Dict[str, float]]]]: A dictionary that maps the type of moderation (text, nsfw, csam) to a dictionary that contains the moderation response.
+
+        Example:
+            {
+                "text_moderation": {
+                    "content": "This is a test",
+                    "response": {
+                        "sexual": 0.1
+                    },
+                    "flagged": True
+                },
+                "nsfw_moderation": {
+                    "image_hash": "1234567890",
+                    "response": {
+                        "IsImageAdultClassified": True
+                    },
+                    "flagged": True
+                },
+                "csam_moderation": {
+                    "image_hash": "1234567890",
+                    "response": {
+                        "IsMatch": True
+                    },
+                    "flagged": True
+                }
+            }
+        """
+        print("moderating text: ", text)
+        self.reset_moderation_flags()
+        text_flagged_map = self.text_moderation_filter(text, model_list, do_moderation)
+
+        if image is not None:
+            image_flagged_map = self.image_moderation_filter(image)
+        else:
+            image_flagged_map = self._NON_TOXIC_IMAGE_MODERATION_MAP
+
+        res = {}
+        res.update(text_flagged_map)
+        res.update(image_flagged_map)
+
+        return res
diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py
index c07ee4669..346281598 100644
--- a/fastchat/serve/monitor/monitor.py
+++ b/fastchat/serve/monitor/monitor.py
@@ -1011,7 +1011,6 @@ def build_leaderboard_tab(
             "avg_tokens": "Average Tokens",
         }
     )
-    model_to_score = {}
     for i in range(len(dataFrame)):
         model_to_score[dataFrame.loc[i, "Model"]] = dataFrame.loc[
             i, "Win-rate"
         ]
diff --git a/fastchat/utils.py b/fastchat/utils.py
index d3531928f..4ec8249e1 100644
--- a/fastchat/utils.py
+++ b/fastchat/utils.py
@@ -149,62 +149,6 @@ def get_gpu_memory(max_gpus=None):
     return gpu_memory
 
 
-def oai_moderation(text, custom_thresholds=None):
-    """
-    Check whether the text violates OpenAI moderation API.
-    """
-    import openai
-
-    client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
-
-    # default to true to be conservative
-    flagged = True
-    MAX_RETRY = 3
-    for _ in range(MAX_RETRY):
-        try:
-            res = client.moderations.create(input=text)
-            flagged = res.results[0].flagged
-            if custom_thresholds is not None:
-                for category, threshold in custom_thresholds.items():
-                    if getattr(res.results[0].category_scores, category) > threshold:
-                        flagged = True
-            break
-        except (openai.OpenAIError, KeyError, IndexError) as e:
-            print(f"MODERATION ERROR: {e}\nInput: {text}")
-    return flagged
-
-
-def moderation_filter(text, model_list, do_moderation=False):
-    # Apply moderation for below models
-    MODEL_KEYWORDS = [
-        "claude",
-        "gpt",
-        "bard",
-        "mistral-large",
-        "command-r",
-        "dbrx",
-        "gemini",
-        "reka",
-        "eureka",
-    ]
-
-    custom_thresholds = {"sexual": 0.3}
-    # set a stricter threshold for claude
-    for model in model_list:
-        if "claude" in model:
-            custom_thresholds = {"sexual": 0.2}
-
-    for keyword in MODEL_KEYWORDS:
-        for model in model_list:
-            if keyword in model:
-                do_moderation = True
-                break
-
-    if do_moderation:
-        return oai_moderation(text, custom_thresholds)
-    return False
-
-
 def clean_flant5_ckpt(ckpt_path):
     """
     Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings,
@@ -439,47 +383,3 @@ def get_image_file_from_gcs(filename):
     contents = blob.download_as_bytes()
 
     return contents
-
-
-def image_moderation_request(image_bytes, endpoint, api_key):
-    headers = {"Content-Type": "image/jpeg", "Ocp-Apim-Subscription-Key": api_key}
-
-    MAX_RETRIES = 3
-    for _ in range(MAX_RETRIES):
-        response = requests.post(endpoint, headers=headers, data=image_bytes).json()
-        try:
-            if response["Status"]["Code"] == 3000:
-                break
-        except:
-            time.sleep(0.5)
-    return response
-
-
-def image_moderation_provider(image, api_type):
-    if api_type == "nsfw":
-        endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"]
-        api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"]
-        response = image_moderation_request(image, endpoint, api_key)
-        print(response)
-        return response["IsImageAdultClassified"]
-    elif api_type == "csam":
-        endpoint = (
-            "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false"
-        )
-        api_key = os.environ["PHOTODNA_API_KEY"]
-        response = image_moderation_request(image, endpoint, api_key)
-        return response["IsMatch"]
-
-
-def image_moderation_filter(image):
-    print(f"moderating image")
-
-    image_bytes = base64.b64decode(image.base64_str)
-
-    nsfw_flagged = image_moderation_provider(image_bytes, "nsfw")
-    csam_flagged = False
-
-    if nsfw_flagged:
-        csam_flagged = image_moderation_provider(image_bytes, "csam")
-
-    return nsfw_flagged, csam_flagged
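
Reviewer note (not part of the patch): a minimal usage sketch of the moderator class added above, assuming OPENAI_API_KEY and the Azure/PhotoDNA credentials are set in the environment; the prompt string and model name below are illustrative only. Inside the Gradio servers the same calls are made through State.content_moderator.

    from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator

    moderator = AzureAndOpenAIContentModerator()

    # Run text (and optionally image) moderation. do_moderation=True forces the
    # OpenAI text check; otherwise it only runs when a keyword-matched model
    # (claude, gpt, gemini, ...) is in the model list.
    response_map = moderator.image_and_text_moderation_filter(
        None, "example user prompt", ["gpt-4o-2024-05-13"], do_moderation=True
    )

    # Flags are stored on the moderator for the calling handler to inspect.
    if moderator.text_flagged or moderator.nsfw_flagged:
        print("flagged:", response_map)

    # Per-turn bookkeeping, mirroring the Gradio handlers: record the response
    # for this turn, then clear it again when the turn is regenerated.
    moderator.append_moderation_response(response_map)
    moderator.update_last_moderation_response(None)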