-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTreeBankDataSet.py
95 lines (68 loc) · 2.85 KB
/
TreeBankDataSet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# -*- coding: utf-8 -*-
"""
Created on Sun Mar 24 15:46:12 2019
@author: durgesh singh
"""
from torch.utils.data import Dataset
import nltk
from nltk.corpus import treebank
from nltk.corpus import brown
from collections import defaultdict
import torch
class CustomDataset(Dataset):
    """Part-of-speech tagging dataset built from an NLTK tagged corpus.

    Each item is a pair ``(word_ids, tag_ids)`` of 1-D ``torch.long`` tensors,
    both right-padded with ``0`` up to the length of the longest sentence in
    the corpus. Words are lower-cased before lookup; tags come from NLTK's
    ``universal`` tagset.

    Args:
        dname: Corpus to load, ``'brown'`` (default) or ``'treebank'``.

    Raises:
        ValueError: If ``dname`` is not a supported corpus name.
    """

    def __init__(self, dname='brown'):
        super().__init__()
        self.data = self._load_corpus(dname)
        vocab, tags = self._build_vocab()
        # default=0 keeps an (unexpectedly) empty corpus from crashing max().
        self.max_sent_len = max(map(len, self.data), default=0)
        # Unknown words map to index 0 ('<UNK>'). `int()` returns 0, so it is a
        # drop-in replacement for `lambda: 0` that, unlike a lambda, can be pickled.
        self.word_to_idx = defaultdict(int, {word: idx for idx, word in enumerate(vocab)})
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        self.tag_to_idx = {tag: idx for idx, tag in enumerate(tags)}
        self.idx_to_tag = {idx: tag for tag, idx in self.tag_to_idx.items()}
        self.sen_list, self.tag_list = self._convert_to_num()

    @staticmethod
    def _ensure_nltk_resource(resource_path, package):
        """Download NLTK data ``package`` unless ``resource_path`` is installed."""
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)

    @classmethod
    def _load_corpus(cls, dname):
        """Return the universal-tagset tagged sentences of corpus ``dname``.

        Missing NLTK data (the corpus itself and the universal tagset mapping)
        is fetched with ``nltk.download`` before the corpus is accessed.

        Raises:
            ValueError: If ``dname`` is neither ``'treebank'`` nor ``'brown'``.
        """
        if dname == 'treebank':
            corpus = treebank
        elif dname == 'brown':
            corpus = brown
        else:
            raise ValueError(
                f"unknown dataset {dname!r}; expected 'treebank' or 'brown'")
        # A corpus that is not installed raises LookupError on first use rather
        # than looking empty, so probe the data path instead of counting words.
        cls._ensure_nltk_resource(f'corpora/{dname}', dname)
        # tagset='universal' needs the mapping tables from this package.
        cls._ensure_nltk_resource('taggers/universal_tagset', 'universal_tagset')
        return corpus.tagged_sents(tagset='universal')

    def get_target_size(self):
        """Return the number of distinct tags."""
        return len(self.tag_to_idx)

    def get_vocab_size(self):
        """Return the vocabulary size, including '<UNK>' and '<EOS>'."""
        # idx_to_word is a plain dict; word_to_idx is a defaultdict that would
        # grow whenever a caller indexes it with an unseen word.
        return len(self.idx_to_word)

    def _convert_to_num(self):
        """Encode every sentence as padded word-id and tag-id tensors.

        Returns:
            ``(sent_list, tag_list)``: two lists holding, per sentence, a 1-D
            ``torch.long`` tensor of length ``self.max_sent_len``.
        """
        width = self.max_sent_len
        sent_list = []
        tag_list = []
        for sentence in self.data:
            # .get() yields 0 ('<UNK>') for unseen words without inserting
            # them into the defaultdict.
            word_ids = [self.word_to_idx.get(word.lower(), 0) for word, _ in sentence]
            tag_ids = [self.tag_to_idx[tag] for _, tag in sentence]
            padding = [0] * (width - len(sentence))
            # NOTE(review): padding reuses index 0, which is '<UNK>' for words and
            # a real tag for tags; any loss over tag_list should mask positions
            # past each sentence's true length.
            sent_list.append(torch.tensor(word_ids + padding, dtype=torch.long))
            tag_list.append(torch.tensor(tag_ids + padding, dtype=torch.long))
        return sent_list, tag_list

    def _build_vocab(self):
        """Collect the lower-cased vocabulary and the tag set of ``self.data``.

        Returns:
            ``(vocab, tags)``: ``vocab`` is ``['<UNK>', '<EOS>']`` followed by
            the sorted distinct lower-cased words; ``tags`` is the sorted list
            of distinct tags. Sorting makes index assignment reproducible
            across processes (``set`` iteration order for strings depends on
            hash randomization).
        """
        words = set()
        tags = set()
        for sentence in self.data:
            for word, tag in sentence:
                words.add(word.lower())
                tags.add(tag)
        # NOTE(review): '<EOS>' reserves index 1 but is never emitted by
        # _convert_to_num.
        return ['<UNK>', '<EOS>'] + sorted(words), sorted(tags)

    def __len__(self):
        return len(self.sen_list)

    def __getitem__(self, idx):
        return self.sen_list[idx], self.tag_list[idx]
# Module-level default instance: building it runs CustomDataset's constructor
# on the Brown corpus as soon as this module is imported.
cdataset = CustomDataset('brown')