diff --git a/ai_eval/shortanswer.py b/ai_eval/shortanswer.py
index cb0e5a3..7918c8f 100644
--- a/ai_eval/shortanswer.py
+++ b/ai_eval/shortanswer.py
@@ -1,11 +1,15 @@
"""Short answers Xblock with AI evaluation."""
+import email
import json
import logging
import traceback
+from xml.sax import saxutils
from django.utils.translation import gettext_noop as _
from web_fragments.fragment import Fragment
+from webob import Response
+from webob.exc import HTTPForbidden, HTTPNotFound
from xblock.core import XBlock
from xblock.exceptions import JsonHandlerError
from xblock.fields import Boolean, Dict, Integer, String, Scope
@@ -18,6 +22,7 @@
logger = logging.getLogger(__name__)
+@XBlock.wants('studio_user_permissions')
class ShortAnswerAIEvalXBlock(AIEvalXBlock):
"""
Short Answer Xblock.
@@ -110,11 +115,10 @@ def get_response(self, data, suffix=""): # pylint: disable=unused-argument
attachments = []
for filename, contents in self.attachments.items():
- # TODO: escape
attachments.append(f"""
- {filename}
- {contents}
+ {saxutils.escape(filename)}
+ {saxutils.escape(contents)}
""")
attachments = '\n'.join(attachments)
@@ -203,6 +207,27 @@ def submit_studio_edits(self, data, suffix=''):
data["values"]["attachments"][key] = self.attachments[key]
return super().submit_studio_edits.__wrapped__(self, data, suffix)
+ @XBlock.handler
+ def view_attachment(self, request, suffix=''):
+ user_perms = self.runtime.service(self, 'studio_user_permissions')
+ if not (user_perms and user_perms.can_read(self.scope_ids.usage_id.context_key)):
+ return request.get_response(HTTPForbidden())
+
+ key = request.GET['key']
+ try:
+ data = self.attachments[key]
+ except KeyError:
+ return request.get_response(HTTPNotFound())
+
+ escaped = key.replace("\\", "\\\\").replace('"', '\\"')
+ return Response(
+ body=data.encode(),
+ headerlist=[
+ ("Content-Type", "application/octet-stream"),
+ ("Content-Disposition", f'attachment; filename="{escaped}"'),
+ ]
+ )
+
@staticmethod
def workbench_scenarios():
"""A canned scenario for display in the workbench."""
diff --git a/ai_eval/static/js/src/shortanswer_edit.js b/ai_eval/static/js/src/shortanswer_edit.js
index d6e8ebd..47261aa 100644
--- a/ai_eval/static/js/src/shortanswer_edit.js
+++ b/ai_eval/static/js/src/shortanswer_edit.js
@@ -4,6 +4,8 @@ function ShortAnswerAIEvalXBlock(runtime, element) {
StudioEditableXBlockMixin(runtime, element);
+ var viewAttachmentUrl = runtime.handlerUrl(element, "view_attachment");
+
var $input = $('#xb-field-edit-attachments');
var buildFileInput = function() {
@@ -13,7 +15,18 @@ function ShortAnswerAIEvalXBlock(runtime, element) {
var files = JSON.parse($input.val() || "{}");
for (var filename of Object.keys(files)) {
var $fileItem = $('
');
- $fileItem.append(filename);
+ var $fileLink;
+ if (files[filename] === null) {
+ /* File that already exists. */
+ $fileLink = $('');
+ $fileLink.attr("target", "_blank");
+ var fileLinkQuery = new URLSearchParams({key: filename}).toString();
+ $fileLink.attr('href', `${viewAttachmentUrl}?${fileLinkQuery}`);
+ } else {
+ $fileLink = $('');
+ }
+ $fileLink.append(filename);
+ $fileItem.append($fileLink);
var $deleteButton = $('');
$deleteButton.append($(''));
var $deleteButtonText = $('');
diff --git a/ai_eval/tests/test_ai_eval.py b/ai_eval/tests/test_ai_eval.py
index da5d4ea..692e34c 100644
--- a/ai_eval/tests/test_ai_eval.py
+++ b/ai_eval/tests/test_ai_eval.py
@@ -3,6 +3,7 @@
"""
import unittest
+from unittest.mock import patch
from xblock.exceptions import JsonHandlerError
from xblock.field_data import DictFieldData
from xblock.test.toy_runtime import ToyRuntime
@@ -85,3 +86,18 @@ def test_reset_forbidden(self):
with self.assertRaises(JsonHandlerError):
block.reset.__wrapped__(block, data={})
self.assertEqual(block.messages, {"USER": ["Hello"], "LLM": ["Hello"]})
+
+ @patch('ai_eval.shortanswer.get_llm_response')
+ def test_attachments(self, get_llm_response):
+ """Test the attachments."""
+ data = {
+ **self.data,
+ "attachments": {"test.json": '{"test": "test"}'},
+ }
+ block = ShortAnswerAIEvalXBlock(ToyRuntime(), DictFieldData(data), None)
+ get_llm_response.return_value = "Hello"
+ block.get_response.__wrapped__(block, data={"user_input": "Hello"})
+ system_msg = get_llm_response.call_args[2][0]["content"]
+ self.assertIn("test.json", system_msg)
+ self.assertIn('{"test": "test"}', system_msg)
+ self.assertEqual(block.messages, {"USER": ["Hello"], "LLM": ["Hello"]})