-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_label_dists.py
executable file
·36 lines (31 loc) · 1.09 KB
/
plot_label_dists.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#! /usr/bin/python3
# Imports
from utils.data import SentimentDataset
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# Config variables
out_file = "figures/label_distributions.png"
path_template = "data/{}"
train = 'train.txt'
wordlist = "wordlist.txt"
productlist = "prdlist.txt"
userlist = "usrlist.txt"
datasets = ["IMDB", "yelp13", "yelp14"]


def label_bins(labels):
    """Return ``(bin_edges, tick_positions)`` for a histogram of integer labels.

    Bins have unit width and start at 0, so each integer value from 0 up to
    the largest label gets its own bar; the extra final edge closes the last
    bin.

    Args:
        labels: sequence or array of integer class labels.

    Returns:
        A pair of ranges: bin edges ``0 .. max+1`` and tick positions
        ``0 .. max``.

    Raises:
        ValueError: if ``labels`` is empty (there is no maximum to size the
            bins from).
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        raise ValueError("cannot build histogram bins for an empty label set")
    top = int(labels.max())
    return range(top + 2), range(top + 1)


def load_labels(name):
    """Load the training split of dataset ``name`` and return its labels.

    Files are resolved under ``path_template.format(name)`` using the
    module-level ``train``/``userlist``/``productlist``/``wordlist`` names.
    """
    path = Path(path_template.format(name))
    dat = SentimentDataset(str(path / train),
                           str(path / userlist),
                           str(path / productlist),
                           str(path / wordlist))
    return np.array(dat.documents["label"])


def main():
    """Plot one label-distribution histogram per dataset and save the figure."""
    plt.style.use('bmh')
    # squeeze=False keeps `axes` 2-D even when only one dataset is configured,
    # so the per-panel loop works for any number of datasets.
    fig, axes = plt.subplots(1, len(datasets), figsize=(12, 3), squeeze=False)
    # Load each dataset and add a plot for its label distribution
    for i, (name, ax) in enumerate(zip(datasets, axes[0])):
        x = load_labels(name)
        bins, ticks = label_bins(x)
        # density=True normalises each panel to unit area so datasets of
        # different sizes share a comparable y-scale; align='left' centres
        # each bar on its integer label value.
        ax.hist(x, density=True, align='left', bins=bins, color=f"C{i}")
        ax.set_xticks(ticks)
        # Plain white panel background instead of the style's default.
        ax.set_facecolor("white")
        ax.set_title(name)
    # Save image to folder, creating it first so a fresh checkout does not
    # fail with FileNotFoundError.
    out_path = Path(out_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=200, pad_inches=0, bbox_inches='tight')
    plt.close(fig)


if __name__ == "__main__":
    main()