diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index ff889b6..58351a0 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -106,13 +106,13 @@ jobs: auditwheel repair dist/*-${{ matrix.cp }}-linux_*.whl TWINE_USERNAME=${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD=${{ secrets.PYPI_PASSWORD }} /opt/python/${{ matrix.cp }}/bin/python -m twine upload wheelhouse/*.whl - build_macos_11: + build_macos_13: name: Build for macOS 13 runs-on: macOS-13 strategy: max-parallel: 4 matrix: - python-version: [3.8, 3.9, "3.10", 3.11] + python-version: [3.8, 3.9, "3.10", 3.11, 3.12] cpu-arch: [x86_64, arm64] steps: @@ -123,7 +123,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools wget https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz tar -zxvf eigen-3.4.0.tar.gz rm eigen-3.4.0.tar.gz diff --git a/.github/workflows/deploy_test.yml b/.github/workflows/deploy_test.yml index 157c2b9..0eb6efe 100644 --- a/.github/workflows/deploy_test.yml +++ b/.github/workflows/deploy_test.yml @@ -105,13 +105,13 @@ jobs: auditwheel repair dist/*-${{ matrix.cp }}-linux_*.whl TWINE_USERNAME=${{ secrets.TEST_PYPI_USERNAME }} TWINE_PASSWORD=${{ secrets.TEST_PYPI_PASSWORD }} /opt/python/${{ matrix.cp }}/bin/python -m twine upload --repository testpypi wheelhouse/*.whl - build_macos_11: + build_macos_13: name: Build for macOS 13 runs-on: macOS-13 strategy: max-parallel: 4 matrix: - python-version: [3.8, 3.9, "3.10", 3.11] + python-version: [3.8, 3.9, "3.10", 3.11, 3.12] cpu-arch: [x86_64, arm64] steps: @@ -122,7 +122,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools wget https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz tar -zxvf eigen-3.4.0.tar.gz rm eigen-3.4.0.tar.gz diff --git a/.github/workflows/generate_documentation.yml b/.github/workflows/generate_documentation.yml index 9faa18e..c6a792b 100644 --- a/.github/workflows/generate_documentation.yml +++ b/.github/workflows/generate_documentation.yml @@ -41,15 +41,15 @@ jobs: mv variant/include/mapbox include/ - name: build run: | - /opt/python/cp38-cp38/bin/python -m pip install numpy==`/opt/python/cp38-cp38/bin/python .github/workflows/numpy_version.py` - /opt/python/cp38-cp38/bin/python -m pip install pdoc3==0.8.4 + /opt/python/cp39-cp39/bin/python -m pip install numpy==`/opt/python/cp39-cp39/bin/python .github/workflows/numpy_version.py` + /opt/python/cp39-cp39/bin/python -m pip install pdoc3==0.8.4 export TOMOTOPY_LANG=${{ matrix.language }} - /opt/python/cp38-cp38/bin/python setup.py install + /opt/python/cp39-cp39/bin/python setup.py install - name: gen doc run: | - export TOMOTOPY_VER="`/opt/python/cp38-cp38/bin/python -m pip show tomotopy | grep Version | cut -d' ' -f2`" + export TOMOTOPY_VER="`/opt/python/cp39-cp39/bin/python -m pip show tomotopy | grep Version | cut -d' ' -f2`" export TOMOTOPY_LANG=${{ matrix.language }} - /opt/python/cp38-cp38/bin/python -m pdoc --html tomotopy + /opt/python/cp39-cp39/bin/python -m pdoc --html tomotopy sed -i -E "s/documentation<\/title>/documentation (v${TOMOTOPY_VER})<\/title>/" html/tomotopy/*.html sed -i -E 's/<\/title>/<\/title>/' html/tomotopy/*.html sed -i -E 's/(

<\/p>)/

+

Top words in each topic

-
- More words... + + +
+ {% set col = "col-md-6" if top1_topic_dist_by_metadata else "" %} +
+

Distribution of Document Top1 Topics

+ + +
+ {% if top1_topic_dist_by_metadata %} +
+

Distribution of Document Top1 Topics by Metadata

+ + +
+ {% end %} +
+
+ {% elif action == 'topic-rel' %} +
+
+
+

Word overlap between topics

+
+ + + + {% for i in range(len(overlaps)) %} + + {% end %} + + {% for j in range(len(overlaps)) %} + + + {% for i in range(len(overlaps)) %} + + {% end %} + + {% end %} +
{{i}}
{{j}}
+
+
+
+ {% for n, (i, j) in enumerate(similar_pairs) %} + = len(overlaps) else ""}}"> + + + + + + + +
{{get_topic_label(i, id_suffix=True)}} - {{get_topic_label(j, id_suffix=True)}}: {{overlaps[i, j]:.2%}}
{{get_topic_label(i, prefix="Topic ", id_suffix=True)}}: + {% for word, p in model.get_topic_words(i, top_n=10) %} + {{word}}({{p:.2%}}) + {% end %} + {{get_topic_label(j, prefix="Topic ", id_suffix=True)}}: + {% for word, p in model.get_topic_words(j, top_n=10) %} + {{word}}({{p:.2%}}) + {% end %} +
+ {% end %} +
+
+
+ {% elif action == 'metadata' %} + +
+
+
+ + + +
+ Confidence + +
+ + + +
+
+
+ {% for i, (s, e) in enumerate(numeric_metadata) %} +
+ Numeric Metadata #{{i}}: Range({{s}} - {{e}}) +
+ + + +
+ + - + + {% if i == len(numeric_metadata) - 1 %} + + {% end %} +
+ {% end %} + +
+ +
+ {% for topic in range(cats.shape[2]) %} +
+

{{get_topic_label(topic, prefix="Topic ", id_suffix=True)}}: {{", ".join(w for w, _ in model.get_topic_words(topic, 5))}}

+ +
+ {% end %} +
+ + + +
+ {% elif action == 'tdf-map' %} +
+

Topic Distribution Function Map

+
+
+ +
+ {% for topic in range(model.k) %} +
+

{{get_topic_label(topic, prefix="Topic ", id_suffix=True)}}

+ Map for {{get_topic_label(topic, prefix= +
+ {% end %} +
{% end %}
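
For reference, the "Word overlap between topics" table above is driven by a histogram-intersection score between topic-word distributions, which the server computes in `get_topic_rel` later in this patch. A minimal standalone sketch of that score (illustrative only; `mdl` is assumed to be an already trained tomotopy model):

    import numpy as np
    import tomotopy as tp

    # mdl is assumed to be a trained tomotopy.LDAModel (or derived model)
    dist = np.stack([mdl.get_topic_word_dist(k) for k in range(mdl.k)])
    # overlap(i, j) = sum over the vocabulary of min(dist[i], dist[j]);
    # 1.0 means identical distributions, values near 0 mean disjoint topics
    overlaps = np.minimum(dist[:, None], dist[None]).sum(-1)
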
@@ -209,34 +556,62 @@

Topic #{{topic.topic_id} (function(){ const tooltip_list = [...document.querySelectorAll('[data-bs-toggle="tooltip"]')].map(e => new bootstrap.Tooltip(e)); + function view_document(doc_id) { + if (doc_id === null) { + // hide document view + document.getElementById('document-content').style.display = 'none'; + document.getElementById('document-list').classList.remove('d-none'); + return; + } + // remove active class + document.querySelectorAll('.doc-item').forEach((el) => { + el.classList.remove('active'); + }); + // add active class + document.getElementById(`doc-item-${doc_id}`).classList.add('active'); + // show loading + document.getElementById('document-content-title').innerHTML = `Doc ${doc_id}`; + document.getElementById('document-metadata').innerHTML = ''; + document.getElementById('document-content-body').innerHTML = '
Loading...
'; + document.getElementById('document-content').style.display = 'block'; + document.getElementById('document-list').classList.add('d-none'); + + fetch(`/api/document/${doc_id}`).then((res) => res.text()).then((html) => { + document.getElementById('document-content-body').innerHTML = html; + const tooltip_list = [...document.querySelectorAll('#document-content-body [data-bs-toggle="tooltip"]')].map(e => new bootstrap.Tooltip(e)); + document.querySelectorAll('#document-content-body meta').forEach((el) => { + const key = el.getAttribute('name'); + const value = el.getAttribute('content'); + document.getElementById('document-metadata').innerHTML += `${key}: ${value} `; + }); + }).catch((err) => { + document.getElementById('document-content-body').innerHTML = `

${err}

`; + }) + } + document.querySelectorAll('.document-view').forEach((el) => { el.addEventListener('click', (e) => { //e.preventDefault(); - const doc_id = el.getAttribute('href').substring(5); - // remove active class - document.querySelectorAll('.doc-item').forEach((el) => { - el.classList.remove('active'); - }); - // add active class - document.getElementById(`doc-item-${doc_id}`).classList.add('active'); - // show loading - document.getElementById('document-content-title').innerHTML = `Doc ${doc_id}`; - document.getElementById('document-content-body').innerHTML = '
Loading...
'; - document.getElementById('document-content').style.display = 'block'; - - fetch(`/api/document/${doc_id}`).then((res) => res.text()).then((html) => { - document.getElementById('document-content-body').innerHTML = html; - const tooltip_list = [...document.querySelectorAll('#document-content-body [data-bs-toggle="tooltip"]')].map(e => new bootstrap.Tooltip(e)); - }).catch((err) => { - document.getElementById('document-content-body').innerHTML = `

${err}

`; - }) + //const doc_id = el.getAttribute('href').substring(5); + //view_document(doc_id); }); }); if (document.querySelector('#document-content-close')) { document.querySelector('#document-content-close').addEventListener('click', (e) => { e.preventDefault(); - document.getElementById('document-content').style.display = 'none'; + //view_document(null); + location.hash = ''; + }); + + // when the user navigates back or forward in history + window.addEventListener('popstate', (e) => { + // test if #doc-xxx exists + if (location.hash.startsWith('#doc-')) { + view_document(location.hash.substring(5)); + } else { + view_document(null); + } }); } @@ -266,6 +641,12 @@

Topic #{{topic.topic_id} }); } + if (document.querySelector('#document-filter-metadata')) { + document.querySelector('#document-filter-metadata').addEventListener('change', (e) => { + document.querySelector('#document-filter').submit(); + }); + } + // attach event handler to element with .topic-action dynamically document.addEventListener("click", function(e){ const target = e.target.closest(".topic-action"); @@ -282,6 +663,188 @@

Topic #{{topic.topic_id} }); }); + if (document.querySelector('.topic-title')) { + document.querySelectorAll('.topic-title').forEach((el) => { + const topic_id = el.getAttribute('data-id'); + el.querySelector('button').addEventListener('click', (e) => { + const button = el.querySelector('button'); + const label = el.querySelector('.topic-label'); + const input = el.querySelector('.topic-label-editable'); + button.style.display = 'none'; + label.style.display = 'none'; + input.style.display = ''; + input.value = label.textContent; + input.focus(); + }); + + function update_topic_label() { + const button = el.querySelector('button'); + const label = el.querySelector('.topic-label'); + const input = el.querySelector('.topic-label-editable'); + button.style.display = ''; + label.style.display = ''; + input.style.display = 'none'; + label.textContent = input.value; + fetch(`/api/topic/${topic_id}/label`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({label: input.value}), + }).then((res) => res.json()).then((json) => { + setTimeout(() => window.location.reload(), 50); + }); + } + el.querySelector('.topic-label-editable').addEventListener('blur', update_topic_label); + el.querySelector('.topic-label-editable').addEventListener('keydown', (e) => { + if (e.key === 'Enter') { + update_topic_label(); + } + }); + }); + } + + if (document.querySelector('#topic-overlap-table')) { + const num_topics = document.querySelectorAll('#topic-overlap-table tr').length - 1; + document.querySelectorAll('#topic-overlap-table td').forEach((el) => { + el.addEventListener('click', (e) => { + const topic_i = el.getAttribute('data-i'); + const topic_j = el.getAttribute('data-j'); + if (topic_i == '-1' || topic_j == '-1') return; + const list = document.querySelector('#most-similar-topic-pairs'); + [...list.children] + .sort((a, b) => { + const a_selected = a.getAttribute('data-j') == topic_j ? 1 : 0; + const b_selected = b.getAttribute('data-j') == topic_j ? 1 : 0; + if (b_selected > a_selected) return 1; + if (b_selected < a_selected) return -1; + return parseFloat(b.getAttribute('data-val')) - parseFloat(a.getAttribute('data-val')); + }) + .forEach((node, idx) => { + if (node.getAttribute('data-j') == topic_j) { + node.classList.add('active'); + } else { + node.classList.remove('active'); + } + node.style.display = node.getAttribute('data-i') == topic_i ? '' : 'none'; + list.appendChild(node); + }); + }); + }); + document.querySelector('#topic-pair-reset').addEventListener('click', (e) => { + const list = document.querySelector('#most-similar-topic-pairs'); + [...list.children] + .sort((a, b) => parseFloat(b.getAttribute('data-val')) - parseFloat(a.getAttribute('data-val'))) + .forEach((node, idx) => { + const is_triu = parseInt(node.getAttribute('data-i')) < parseInt(node.getAttribute('data-j')); + node.style.display = idx < (num_topics - 1) * 2 && is_triu ? 
'' : 'none'; + list.appendChild(node); + }); + }); + } + + function get_confidence_interval_data(p, done, cid = 0) { + if (typeof tdf_ci_lb[cid] !== 'undefined' && tdf_ci_p[cid] == p) { + if (cid + 1 < category_labels.length) { + get_confidence_interval_data(p, done, cid + 1); + } else { + done(); + } + return; + } + + fetch(`/api/conf-interval/${cid}/${p}`).then((res) => res.json()).then((json) => { + const cid = json.data.cid; + const p = json.data.p; + const lbs = json.data.lbs; + const ubs = json.data.ubs; + tdf_ci_p[cid] = p; + tdf_ci_lb[cid] = lbs; + tdf_ci_ub[cid] = ubs; + if (cid + 1 < category_labels.length) { + get_confidence_interval_data(p, done, cid + 1); + } else { + done(); + } + }); + } + + function updateConfInterval() { + const show_conf_interval = document.querySelector('#show-conf-interval').checked; + if (show_conf_interval) { + const p = parseFloat(document.querySelector('#confidence').value); + if (isNaN(p) || p <= 0 || p >= 1) { + alert('Confidence should be a number between 0 and 1.'); + return; + } + get_confidence_interval_data(p, () => { + for (var i in charts) { + const chart = charts[i]; + const data = chart.data; + const datasets = data.datasets; + datasets.splice(category_labels.length); + for (var cid in category_labels) { + const lb = tdf_ci_lb[cid][i]; + const ub = tdf_ci_ub[cid][i]; + for (var lv = 0; lv < 3; ++lv) { + const backgroundColor = datasets[cid].backgroundColor.replace(/, *0\.\d+\)/, lv > 0 ? ', 0.1)' : ', 0.15)'); + datasets.push({ + label: category_labels[cid] + ' LB', + data: lb.map((v, x) => (v * (3 - lv) + tdf_data[i][cid][x] * lv) / 3), + fill: cid, + backgroundColor: backgroundColor, + borderColor: 'transparent', + pointStyle: false, + }); + datasets.push({ + label: category_labels[cid] + ' UB', + data: ub.map((v, x) => (v * (3 - lv) + tdf_data[i][cid][x] * lv) / 3), + fill: cid, + backgroundColor: backgroundColor, + borderColor: 'transparent', + pointStyle: false, + }); + } + } + chart.update(); + } + }); + } else { + for (var i in charts) { + const chart = charts[i]; + const data = chart.data; + const datasets = data.datasets; + datasets.splice(category_labels.length); + chart.update(); + } + } + } + + if (document.querySelector('#show-conf-interval')) { + document.querySelector('#show-conf-interval').addEventListener('change', updateConfInterval); + document.querySelector('#confidence').addEventListener('change', updateConfInterval); + } + + if (document.querySelector('#opacity')) { + document.querySelector('#opacity').addEventListener('change', (e) => { + for (var i in charts) { + const chart = charts[i]; + chart.update(); + } + }); + } + + if (document.querySelector('#numeric-metadata-range')) { + document.querySelector('#numeric-metadata-range').addEventListener('submit', (e) => { + const range_values = []; + document.querySelectorAll('.range-value').forEach((el) => { + range_values.push(el.value); + }); + document.querySelector('#all-range-value').value = range_values.join(','); + return true; + }); + } + }()); diff --git a/tomotopy/viewer/viewer_server.py b/tomotopy/viewer/viewer_server.py index 7e5b766..6e62f14 100644 --- a/tomotopy/viewer/viewer_server.py +++ b/tomotopy/viewer/viewer_server.py @@ -3,6 +3,12 @@ import io import urllib.parse import html +import math +import json +import csv +import io +import functools +from collections import defaultdict, Counter from dataclasses import dataclass import traceback import http.server @@ -27,24 +33,157 @@ class Topic: topic_id: int words: list -def hue2rgb(h:float): +def 
hue2rgb(h:float, b:float = 1): h = h % 6.0 + b = min(max(b, 0), 1) if h < 1: - return (255, int(255 * h), 0) + return (int(255 * b), int(255 * h * b), 0) elif h < 2: - return (int(255 * (2 - h)), 255, 0) + return (int(255 * (2 - h) * b), int(255 * b), 0) elif h < 3: - return (0, 255, int(255 * (h - 2))) + return (0, int(255 * b), int(255 * (h - 2) * b)) elif h < 4: - return (0, int(255 * (4 - h)), 255) + return (0, int(255 * (4 - h) * b), int(255 * b)) elif h < 5: - return (int(255 * (h - 4)), 0, 255) + return (int(255 * (h - 4) * b), 0, int(255 * b)) + else: + return (int(255 * b), 0, int(255 * (6 - h) * b)) + +def scale_color(s:float, scale='log'): + if scale == 'log': + s = min(max(s, 0) + 1e-4, 1) + s = max(math.log(s) + 4, 0) * (6 / 4) else: - return (255, 0, int(255 * (6 - h))) + s = min(max(s, 0), 1) + s *= 6 + if s < 1: + return hue2rgb(4, s * 0.6) + elif s < 5: + return hue2rgb(5 - s, (s - 1) / 4 * 0.4 + 0.6) + else: + t = int((s - 5) * 255) + return (255, t, t) + +def colorize(a, colors): + a = a * (len(colors) - 1) + l = np.floor(a).astype(np.int32) + r = np.clip(l + 1, 0, len(colors) - 1) + a = a - l + result = colors[l] + (colors[r] - colors[l]) * a[..., None] + return result + +def draw_contour_map(arr, interval, smooth=True): + def _refine(a, smooth, cut=0.15, scale=0.6): + if smooth: + s = np.zeros_like(a, dtype=np.float32) + a = np.pad(a, (2, 2), 'edge') + + # approximated 5x5 gaussian filter + + s += a[2:-2, 2:-2] * 41 + + s += a[1:-3, 2:-2] * 26 + s += a[3:-1, 2:-2] * 26 + s += a[2:-2, 1:-3] * 26 + s += a[2:-2, 3:-1] * 26 + + s += a[1:-3, 1:-3] * 16 + s += a[3:-1, 3:-1] * 16 + s += a[1:-3, 3:-1] * 16 + s += a[3:-1, 1:-3] * 16 + + s += a[4:, 2:-2] * 7 + s += a[:-4, 2:-2] * 7 + s += a[2:-2, 4:] * 7 + s += a[2:-2, :-4] * 7 + + s += a[1:-3, 4:] * 4 + s += a[3:-1, :-4] * 4 + s += a[4:, 3:-1] * 4 + s += a[:-4, 1:-3] * 4 + + s += a[4:, 4:] * 1 + s += a[:-4, :-4] * 1 + s += a[4:, :-4] * 1 + s += a[:-4, 4:] * 1 + + s /= 273 + + a = (s - cut) / scale + return a.clip(0, 1) + + cv = np.floor(arr / interval) + contour_map = np.zeros_like(arr, dtype=np.float32) + contour_map[:-1] += cv[:-1] != cv[1:] + contour_map[1:] += cv[:-1] != cv[1:] + contour_map[:, :-1] += cv[:, :-1] != cv[:, 1:] + contour_map[:, 1:] += cv[:, :-1] != cv[:, 1:] + contour_map = _refine(contour_map, smooth, cut=0.25, scale=0.65) + + cv = np.floor(arr / (interval * 5)) + contour_map2 = np.zeros_like(arr, dtype=np.float32) + contour_map2[:-1] += cv[:-1] != cv[1:] + contour_map2[1:] += cv[:-1] != cv[1:] + contour_map2[:, :-1] += cv[:, :-1] != cv[:, 1:] + contour_map2[:, 1:] += cv[:, :-1] != cv[:, 1:] + contour_map2 = _refine(contour_map2, smooth, cut=0.15, scale=0.6) + contour_map = (contour_map + contour_map2) / 2 + + return contour_map topic_colors = [hue2rgb((i * 7) / 40 * 6) for i in range(40)] topic_styles = [(i * 7 // 40) % 3 for i in range(40)] +def find_best_labels_for_range(start, end, max_labels): + dist = end - start + unit = 1 + for i in range(10): + u = 10 ** (-i) + if dist < u: + continue + r = dist - round(dist / u) * u + if abs(r) < u * 0.1: + unit = u + break + steps = round(dist / unit) + while steps > max_labels: + unit *= 10 + steps = round(dist / unit) + + if steps <= max_labels / 5: + unit /= 5 + steps = round(dist / unit) + elif steps <= max_labels / 2: + unit /= 2 + steps = round(dist / unit) + + s = int(math.floor(start / unit)) + e = int(math.ceil(end / unit)) + return [unit * i for i in range(s, e + 1)] + +def estimate_confidence_interval_of_dd(alpha, p=0.95, samples=16384): + rng = 
np.random.RandomState(0) + alpha = np.array(alpha, dtype=np.float32) + mean = alpha / alpha.sum() + t = rng.dirichlet(alpha, samples).astype(np.float32) + t.sort(axis=0) + cnt = int(samples * (1 - p)) + i = (t[-cnt:] - t[:cnt]).argmin(0) + + o = np.array([np.searchsorted(t[:, i], m, 'right') for i, m in enumerate(mean)]) + i = np.maximum(i, o - (samples - cnt)) + + lb = t[i, np.arange(len(alpha))] + ub = t[i - cnt, np.arange(len(alpha))] + return lb, ub + +def is_iterable(obj): + try: + iter(obj) + return True + except: + return False + class DocumentFilter: def __init__(self, model, max_cache_size=5) -> None: self.model = model @@ -52,9 +191,11 @@ def __init__(self, model, max_cache_size=5) -> None: self._cached = {} self._cached_keys = [] - def _sort_and_filter(self, sort_key:int, filter_target:int, filter_value:float): + def _sort_and_filter(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:tuple, filter_metadata:str): results = [] for i, doc in enumerate(self.model.docs): + if filter_keyword and not all(kw in doc.raw.lower() for kw in filter_keyword): continue + if filter_metadata is not None and doc.metadata != filter_metadata: continue dist = doc.get_topic_dist() if dist[filter_target] < filter_value: continue if sort_key >= 0: @@ -67,37 +208,49 @@ def _sort_and_filter(self, sort_key:int, filter_target:int, filter_value:float): else: return [i for _, i in sorted(results, reverse=True)] - def _get_cached_filter_result(self, sort_key:int, filter_target:int, filter_value:float): - if sort_key < 0 and filter_value <= 0: + def _get_cached_filter_result(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:str, filter_metadata:str): + filter_keyword = tuple(filter_keyword.lower().split()) + if sort_key < 0 and filter_value <= 0 and not filter_keyword and filter_metadata is None: # return None for no filtering nor sorting return None - key = (sort_key, filter_target, filter_value) + key = (sort_key, filter_target, filter_value, filter_keyword, filter_metadata) if key in self._cached: return self._cached[key] else: - result = self._sort_and_filter(sort_key, filter_target, filter_value) + result = self._sort_and_filter(sort_key, filter_target, filter_value, filter_keyword, filter_metadata) if len(self._cached_keys) >= self.max_cache_size: del self._cached[self._cached_keys.pop(0)] self._cached[key] = result self._cached_keys.append(key) return result - def get(self, sort_key:int, filter_target:int, filter_value:float, index:slice): + def get(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:str, filter_metadata:str, index:slice): # return (doc_indices, total_docs_filtered) - result = self._get_cached_filter_result(sort_key, filter_target, filter_value) + result = self._get_cached_filter_result(sort_key, filter_target, filter_value, filter_keyword, filter_metadata) if result is None: return list(range(index.start, min(index.stop, len(self.model.docs)))), len(self.model.docs) else: return result[index], len(result) - + class ViewerHandler(http.server.SimpleHTTPRequestHandler): - handlers = [ + get_handlers = [ (r'/?', 'overview'), (r'/document/?', 'document'), (r'/topic/?', 'topic'), + (r'/topic-rel/?', 'topic_rel'), + (r'/metadata/?', 'metadata'), + (r'/tdf-map/?', 'tdf_map'), (r'/api/document/(\d+)', 'api_document'), + (r'/api/conf-interval/(\d+)/([0-9.]+)', 'api_conf_interval'), + (r'/d/topic-words\.csv', 'download_topic_words'), + (r'/d/document-top1-topic\.csv', 'download_document_top1_topic'), + 
(r'/d/tdf-map-([0-9]+|legend).png', 'download_tdf_map'), + ] + + post_handlers = [ + (r'/api/topic/(\d+)/label', 'api_update_topic_label'), ] num_docs_per_page = 30 @@ -110,11 +263,88 @@ def title(self): def model(self): return self.server.model + @property + def model_hash(self): + hex_chr = hex(self.model.get_hash())[2:] + if len(hex_chr) < 32: + hex_chr = '0' * (32 - len(hex_chr)) + hex_chr + return hex_chr + + @property + def available(self): + ret = {} + if 'GDMR' in type(self.model).__name__: + ret['metadata'] = True + return ret + @property def tomotopy_version(self): import tomotopy return tomotopy.__version__ + @property + def user_config(self): + return self.get_user_config(None) + + @property + def read_only(self): + return self.server.read_only + + def get_topic_label(self, k, prefix='', id_suffix=False): + label = self.get_user_config(('topic_label', k)) + if label is None: + label = f'{prefix}#{k}' + elif id_suffix: + label += f' #{k}' + return label + + def get_all_topic_labels(self, prefix='', id_suffix=False): + return [self.get_topic_label(k, prefix, id_suffix) for k in range(self.model.k)] + + def get_user_config(self, key): + if self.server.user_config is None: + if self.server.user_config_file: + try: + self.server.user_config = json.load(open(self.server.user_config_file, 'r', encoding='utf-8')) + model_hash_in_config = self.server.user_config.get('model_hash') + if model_hash_in_config is not None and model_hash_in_config != self.model_hash: + print(f'User config file is for a different model. Ignoring the file.') + self.server.user_config = {} + except FileNotFoundError: + self.server.user_config = {} + else: + self.server.user_config = {} + self.server.user_config['model_hash'] = self.model_hash + if key is None: + return self.server.user_config + + if isinstance(key, str) or not is_iterable(key): + key = [key] + + obj = self.server.user_config + for k in key: + obj = obj.get(str(k)) + if obj is None: + return obj + return obj + + def set_user_config(self, key, value): + self.get_user_config(key) + + if isinstance(key, str) or not is_iterable(key): + key = [key] + + obj = self.server.user_config + for k in key[:-1]: + k = str(k) + if k not in obj: + obj[k] = {} + obj = obj[k] + obj[str(key[-1])] = value + + if self.server.user_config_file: + json.dump(self.server.user_config, open(self.server.user_config_file, 'w', encoding='utf-8'), ensure_ascii=False, indent=2) + def render(self, **kwargs): local_vars = {} for k in dir(self): @@ -131,7 +361,7 @@ def do_GET(self): parsed = urllib.parse.urlparse(self.path) path, query = parsed.path, parsed.query self.arguments = {k:v[0] if len(v) == 1 else v for k, v in urllib.parse.parse_qs(query).items()} - for pattern, handler in self.handlers: + for pattern, handler in self.get_handlers: m = re.fullmatch(pattern, path) if m: try: @@ -143,6 +373,22 @@ def do_GET(self): return self.send_error(404) + def do_POST(self): + parsed = urllib.parse.urlparse(self.path) + path, query = parsed.path, parsed.query + self.arguments = json.loads(self.rfile.read(int(self.headers['Content-Length']))) + for pattern, handler in self.post_handlers: + m = re.fullmatch(pattern, path) + if m: + try: + getattr(self, 'post_' + handler)(*m.groups()) + except: + self.send_error(500) + traceback.print_exc() + self.wfile.write(traceback.format_exc().encode()) + return + self.send_error(404) + def get_overview(self): import tomotopy._summary as tps buf = io.StringIO() @@ -179,12 +425,22 @@ def get_document(self): sort_key = int(self.arguments.get('s', '-1')) 
filter_target = int(self.arguments.get('t', '0')) filter_value = float(self.arguments.get('v', '0')) + filter_keyword = self.arguments.get('sq', '') + filter_metadata = int(self.arguments.get('m', '-1')) page = int(self.arguments.get('p', '0')) + if not self.available.get('metadata') or filter_metadata < 0: + filter_metadata = -1 + md = None + else: + md = self.model.metadata_dict[filter_metadata] + doc_indices, filtered_docs = self.server.filter.get( sort_key, filter_target, filter_value / 100, + filter_keyword, + md, slice(page * self.num_docs_per_page, (page + 1) * self.num_docs_per_page) ) total_pages = (filtered_docs + self.num_docs_per_page - 1) // self.num_docs_per_page @@ -199,14 +455,34 @@ def get_document(self): self.render(action='document', page=page, total_pages=total_pages, - filtered_docs=filtered_docs if filter_value > 0 else None, + filtered_docs=filtered_docs if filter_value > 0 or filter_keyword or filter_metadata >= 0 else None, total_docs=total_docs, documents=documents, sort_key=sort_key, filter_target=filter_target, filter_value=filter_value, + filter_keyword=filter_keyword, + filter_metadata=filter_metadata, ) + def prepare_topic_doc_stats(self): + all_cnt = Counter([doc.get_topics(1)[0][0] for doc in self.model.docs]) + top1_topic_dist = [all_cnt[i] for i in range(self.model.k)] + + try: + has_metadata = len(self.model.docs[0].metadata) > 1 + except: + has_metadata = False + + if has_metadata: + top1_topic_dist_by_metadata = defaultdict(Counter) + for doc in self.model.docs: + top1_topic_dist_by_metadata[doc.metadata][doc.get_topics(1)[0][0]] += 1 + for k, cnt in top1_topic_dist_by_metadata.items(): + top1_topic_dist_by_metadata[k] = [cnt[i] for i in range(self.model.k)] + + return top1_topic_dist, top1_topic_dist_by_metadata if has_metadata else None + def get_topic(self): top_n = int(self.arguments.get('top_n', '10')) alpha = float(self.arguments.get('alpha', '0.0')) @@ -219,21 +495,125 @@ def get_topic(self): top_words = (-weighted_topic_word_dist).argsort()[:, :top_n] max_dist = topic_word_dist.max() for k, top_word in enumerate(top_words): - topic_words = [(self.model.vocabs[w], topic_word_dist[k, w]) for w in top_word] + topic_words = [(self.model.vocabs[w], w, topic_word_dist[k, w]) for w in top_word] topics.append(Topic(k, topic_words)) else: for k in range(self.model.k): - topic_words = self.model.get_topic_words(k, top_n) - max_dist = max(max_dist, topic_words[0][1]) + topic_words = self.model.get_topic_words(k, top_n, return_id=True) + max_dist = max(max_dist, topic_words[0][-1]) topics.append(Topic(k, topic_words)) + top1_topic_dist, top1_topic_dist_by_metadata = self.prepare_topic_doc_stats() + self.send_response(200) self.send_header('Content-type', 'text/html') self.end_headers() self.render(action='topic', topics=topics, max_dist=max_dist, - top_n=top_n,) + top_n=top_n, + top1_topic_dist=top1_topic_dist, + top1_topic_dist_by_metadata=top1_topic_dist_by_metadata, + ) + + def get_topic_rel(self): + topic_word_dist = np.stack([self.model.get_topic_word_dist(k) for k in range(self.model.k)]) + overlaps = np.minimum(topic_word_dist[:, None], topic_word_dist[None]).sum(-1) + similar_pairs = np.stack(np.unravel_index((-np.triu(overlaps, 1)).flatten().argsort(), overlaps.shape), -1) + similar_pairs = similar_pairs[similar_pairs[:, 0] != similar_pairs[:, 1]] + most_similars = (2 * np.eye(len(overlaps)) - overlaps).argsort()[:, :-1] + + self.send_response(200) + self.send_header('Content-type', 'text/html') + self.end_headers() + 
self.render(action='topic-rel', + overlaps=overlaps, + similar_pairs=similar_pairs, + most_similars=most_similars, + ) + + def prepare_metadata(self): + axis = int(self.arguments.get('axis', '0')) + x = self.arguments.get('x', '') + resolution = int(self.arguments.get('r', '33')) + numeric_metadata = self.model.metadata_range + if axis < 0 or axis >= len(numeric_metadata): + axis = 0 + + if x: + x = list(map(float, x.split(','))) + x = list(zip(x[::2], x[1::2])) + else: + x = [((s, e) if i == axis else (s, s)) for i, (s, e) in enumerate(numeric_metadata)] + + start, end = zip(*x) + num = [resolution if i == axis else 1 for i in range(len(x))] + squeeze_axis = tuple(i for i in range(len(x)) if i != axis) + return start, end, num, squeeze_axis, axis, numeric_metadata + + def compute_data_density(self, x_values, axis, categorical_metadata): + dist = defaultdict(list) + for d in self.model.docs: + dist[d.metadata].append(d.numeric_metadata[axis]) + + s, e = self.model.metadata_range[axis] + kernel_size = (e - s) / (self.model.degrees[axis] + 1) + + densities = [] + for c in categorical_metadata: + points = np.array(dist[c], dtype=np.float32) + density = np.exp(-((x_values[:, None] - points) / kernel_size) ** 2).sum(-1) + density /= density.max() + densities.append(density) + return densities + + def get_metadata(self): + (start, end, num, squeeze_axis, axis, numeric_metadata + ) = self.prepare_metadata() + max_labels = int(self.arguments.get('max_labels', '15')) + categorical_metadata = self.model.metadata_dict + + x_values = np.linspace(start[axis], end[axis], num[axis], dtype=np.float32) + data_density = self.compute_data_density(x_values, axis, categorical_metadata) + boundaries = np.array(find_best_labels_for_range(x_values[0], x_values[-1], max_labels)) + t = (np.searchsorted(boundaries, x_values, 'right') - 1).clip(0) + x_labels = [f'{boundaries[t[i]]:g}' if i == 0 or t[i - 1] != t[i] else '' for i in range(len(x_values))] + + cats = np.stack([self.model.tdf_linspace(start, end, num, metadata=c).squeeze(squeeze_axis) for c in categorical_metadata]) + + self.send_response(200) + self.send_header('Content-type', 'text/html') + self.end_headers() + self.render(action='metadata', + categorical_metadata=categorical_metadata, + numeric_metadata=numeric_metadata, + range_start=start, + range_end=end, + x_values=x_values, + x_labels=x_labels, + axis=axis, + cats=cats, + data_density=data_density, + ) + + def get_tdf_map(self): + x = int(self.arguments.get('x', '0')) + y = int(self.arguments.get('y', '1')) + width = int(self.arguments.get('w', '640')) + height = int(self.arguments.get('h', '480')) + contour_interval = float(self.arguments.get('s', '0.2')) + smooth = bool(int(self.arguments.get('smooth', '1'))) + + self.send_response(200) + self.send_header('Content-type', 'text/html') + self.end_headers() + self.render(action='tdf-map', + x_axis=x, + y_axis=y, + width=width, + height=height, + contour_interval=contour_interval, + smooth=smooth,) def get_api_document(self, doc_id): doc_id = int(doc_id) @@ -252,11 +632,21 @@ def get_api_document(self, doc_id): chunks.append(html.escape(raw[last:])) html_cont = '

' + ''.join(chunks).strip().replace('\n', '
') + '

' + meta = [] + if hasattr(doc, 'metadata'): + meta.append(f'') + if hasattr(doc, 'multi_metadata'): + meta.append(f'') + if hasattr(doc, 'numeric_metadata'): + meta.append(f'') + if meta: + html_cont = '\n'.join(meta) + '\n' + html_cont + chunks = [''] for topic_id, dist in doc.get_topics(top_n=-1): chunks.append( f''' - + @@ -268,6 +658,183 @@ def get_api_document(self, doc_id): self.end_headers() self.wfile.write(html_cont.encode()) + def get_api_conf_interval(self, cid, p=0.95): + cid = int(cid) + p = float(p) + (start, end, num, squeeze_axis, axis, numeric_metadata + ) = self.prepare_metadata() + categorical_metadata = self.model.metadata_dict + alphas = np.exp(self.model.tdf_linspace(start, end, num, metadata=categorical_metadata[cid], normalize=False).squeeze(squeeze_axis)) + lbs = [] + ubs = [] + for alpha in alphas: + lb, ub = estimate_confidence_interval_of_dd(alpha, p=p, samples=10000) + lbs.append(lb) + ubs.append(ub) + lbs = np.stack(lbs, axis=-1).tolist() + ubs = np.stack(ubs, axis=-1).tolist() + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + + self.wfile.write(json.dumps({'data':{'cid':cid, 'p':p, 'lbs': lbs, 'ubs': ubs}}, ensure_ascii=False).encode()) + + def post_api_update_topic_label(self, topic_id): + if self.read_only: + self.send_error(403) + return + topic_id = int(topic_id) + label = self.arguments.get('label', '') or None + self.set_user_config(('topic_label', topic_id), label) + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'topic_id':topic_id, 'label':label}, ensure_ascii=False).encode()) + + def get_download_topic_words(self): + n = int(self.arguments.get('n', '10')) + csv_buf = io.StringIO() + writer = csv.writer(csv_buf) + headers = [''] + words = [] + for k in range(self.model.k): + headers.append(self.get_topic_label(k, prefix='Topic ', id_suffix=True)) + headers.append('Prob.') + words.append(self.model.get_topic_words(k, top_n=n)) + + writer.writerow(headers) + for i in range(n): + row = [i + 1] + for k in range(self.model.k): + row.extend(words[k][i]) + writer.writerow(row) + + self.send_response(200) + self.send_header('Content-type', 'text/csv') + self.send_header('Content-Disposition', 'attachment; filename="topic-words.csv"') + self.end_headers() + self.wfile.write(csv_buf.getvalue().encode('utf-8-sig')) + + def get_download_document_top1_topic(self): + metadata = int(self.arguments.get('m', '0')) + csv_buf = io.StringIO() + writer = csv.writer(csv_buf) + top1_topic_dist, top1_topic_dist_by_metadata = self.prepare_topic_doc_stats() + if metadata: + headers = ['', *self.model.metadata_dict] + writer.writerow(headers) + for k in range(self.model.k): + row = [self.get_topic_label(k, prefix='Topic ', id_suffix=True)] + for m in self.model.metadata_dict: + row.append(top1_topic_dist_by_metadata[m][k]) + writer.writerow(row) + else: + headers = ['', 'All'] + writer.writerow(headers) + for k, cnt in enumerate(top1_topic_dist): + writer.writerow([self.get_topic_label(k, prefix='Topic ', id_suffix=True), cnt]) + + self.send_response(200) + self.send_header('Content-type', 'text/csv') + if metadata: + self.send_header('Content-Disposition', 'attachment; filename="document-top1-topic-by-metadata.csv"') + else: + self.send_header('Content-Disposition', 'attachment; filename="document-top1-topic.csv"') + self.end_headers() + self.wfile.write(csv_buf.getvalue().encode('utf-8-sig')) + + def __eq__(self, other): + return 
self.model.get_hash() == other.model.get_hash() + + def __hash__(self): + return self.model.get_hash() + + @functools.lru_cache(maxsize=128) + def cached_tdf_linspace(self, start, end, num, metadata=""): + return self.model.tdf_linspace(start, end, num, metadata=metadata) + + @functools.lru_cache(maxsize=128) + def cache_tdf_map_img(self, topic_id, x, y, w, h, r, contour_interval, smooth): + from PIL import Image + start, end = zip(*r) + num = [1] * len(start) + num[x] = w + num[y] = h + + metadata = int(self.arguments.get('m', '0')) + metadata = self.model.metadata_dict[metadata] if metadata >= 0 else "" + + td = self.cached_tdf_linspace(tuple(start), tuple(end), tuple(num), metadata) + td = td.transpose([-1, y, x] + [i for i in range(len(start)) if i not in (x, y)]).squeeze() + max_val = np.log(td.max() + 1e-9) - 1 + min_val = -7 + if topic_id == 'legend': + logits = np.linspace(min_val, max_val, w)[None] + logits = np.repeat(logits, h, 0) + smooth = False + else: + logits = np.log(td[topic_id] + 1e-9) + logits = logits[::-1] # invert y-axis + scaled = (logits - min_val) / (max_val - min_val) + scaled = np.clip(scaled, 0, 1) + + contour_map = draw_contour_map(logits, contour_interval, smooth) + + colors = np.array([ + [0, 0, 0], + [0, 0, 0.7], + [0, 0.75, 0.75], + [0, 0.8, 0], + [0.85, 0.85, 0], + [1, 0, 0], + [1, 1, 1], + ], dtype=np.float32) + colorized = colorize(scaled, colors) + if topic_id == 'legend': + is_sub_grid = contour_map[0] < 1 + contour_map[:-int(h * 0.32)] = 0 + contour_map[:-int(h * 0.16), is_sub_grid] = 0 + contour_map = contour_map.clip(0, 1) + colorized *= 1 - contour_map[..., None] + img = Image.fromarray((colorized * 255).astype(np.uint8), 'RGB') + img_buf = io.BytesIO() + img.save(img_buf, format='PNG') + img_buf.seek(0) + return img_buf.read() + + def get_download_tdf_map(self, topic_id): + if not hasattr(self.model, 'tdf_linspace'): + self.send_error(404) + return + if topic_id == 'legend': + pass + else: + topic_id = int(topic_id) + if topic_id >= self.model.k: + self.send_error(404) + return + + x = int(self.arguments.get('x', '0')) + y = int(self.arguments.get('y', '1')) + w = int(self.arguments.get('w', '640')) + h = int(self.arguments.get('h', '480')) + contour_interval = float(self.arguments.get('s', '0.2')) + smooth = bool(int(self.arguments.get('smooth', '1'))) + + r = self.arguments.get('r', '') + if r: + r = list(map(float, r.split(','))) + r = list(zip(r[::2], r[1::2])) + else: + r = self.model.metadata_range + + img_buf = self.cache_tdf_map_img(topic_id, x, y, w, h, tuple(r), contour_interval, smooth) + self.send_response(200) + self.send_header('Content-type', 'image/png') + self.end_headers() + self.wfile.write(img_buf) + def _repl(m): if m.group().startswith('{{'): inner = m.group(1) @@ -312,7 +879,30 @@ def _prepare_template(): compiled_template = compile('\n'.join(codes), 'template.html', 'exec') return compiled_template -def open_viewer(model, host='localhost', port=80, title=None): +def open_viewer(model, host='localhost', port=80, title=None, user_config_file=None, read_only=False): + ''' +Run a server for the topic model viewer + +Parameters +---------- +model: tomotopy.LDAModel or its derived class + A trained topic model instance to be visualized. +host: str + The host name to bind the server. Default is 'localhost'. +port: int + The port number to bind the server. Default is 80. +title: str + The title of the viewer in a web browser. Default is the class name of the model. 
+user_config_file: str + The path to a JSON file in which user configurations are stored. Default is `None`. If `None`, user configurations are not saved. +read_only: bool + If True, the viewer is read-only, i.e., users cannot change topic labels. Default is False. + +Note +---- +Using this in a production web service is not recommended, +because it relies on Python's built-in `http.server` module, which is not designed for high-performance production environments. + ''' import tomotopy as tp if not isinstance(model, tp.LDAModel): raise ValueError(f'`model` must be an instance of tomotopy.LDAModel, but {model!r} was given.') @@ -326,7 +916,9 @@ httpd.title = title httpd.model = model httpd.template = template + httpd.user_config_file = user_config_file + httpd.user_config = None + httpd.read_only = read_only httpd.filter = DocumentFilter(model) print(f'Serving a topic model viewer at http://{httpd.server_address[0]}:{httpd.server_address[1]}/') httpd.serve_forever()
- Topic #{topic_id}
+ {self.get_topic_label(topic_id, prefix="Topic ", id_suffix=True)}
  {dist:.3%}
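
A minimal usage sketch for the extended `open_viewer` signature added above (the model path and port are illustrative, and `open_viewer` is assumed to be importable from `tomotopy.viewer`, as the module layout suggests):

    import tomotopy as tp
    from tomotopy.viewer import open_viewer

    # load a previously trained model; 'model.bin' is an illustrative path
    mdl = tp.LDAModel.load('model.bin')
    # user_config_file persists user settings such as edited topic labels;
    # read_only=True would reject label edits sent to POST /api/topic/<k>/label
    open_viewer(mdl, host='localhost', port=8080,
                user_config_file='viewer_config.json', read_only=False)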