diff --git a/.gitignore b/.gitignore
index e7f41300..aea81151 100755
--- a/.gitignore
+++ b/.gitignore
@@ -21,6 +21,7 @@ mySVM.joblib
 .ipynb_checkpoints/*
 *.ipynb_checkpoints*
 data
+models
 *.json
 *.log
 *.txt
diff --git a/main.py b/main.py
index 2bed4878..a5a5f2b0 100755
--- a/main.py
+++ b/main.py
@@ -12,7 +12,6 @@
 
 # TODO: eliminate features that occur only n times ?
 # Do the Moisl Selection ?
-# TODO: free up memory as the script goes by deleting unnecessary objects
 
 
 if __name__ == '__main__':
@@ -62,7 +61,8 @@
                                       keep_punct=args.keep_punct, keep_sym=args.keep_sym, max_samples=args.max_samples)
 
     else:
-        myTexts = tuy.load_texts(args.s, identify_lang=args.identify_lang, format=args.x, keep_punct=args.keep_punct, keep_sym=args.keep_sym)
+        myTexts = tuy.load_texts(args.s, identify_lang=args.identify_lang, format=args.x, keep_punct=args.keep_punct,
+                                 keep_sym=args.keep_sym, max_samples=args.max_samples)
 
 
     print(".......getting features.......")
@@ -96,15 +96,6 @@
 
     print(".......feeding data frame.......")
 
-    #feats = pandas.DataFrame(columns=list(feat_list), index=unique_texts)
-
-
-    # with Pool(args.p) as pool:
-    #     print(args.p)
-    #     target = zip(myTexts, [feat_list] * len(myTexts))
-    #     with tqdm.tqdm(total=len(myTexts)) as pbar:
-    #         for text, local_freqs in pool.map(count_process, target):
-
     loc = {}
 
     for t in tqdm.tqdm(myTexts):
diff --git a/superstyl/preproc/tuyau.py b/superstyl/preproc/tuyau.py
index 200e18a6..133d1c6e 100755
--- a/superstyl/preproc/tuyau.py
+++ b/superstyl/preproc/tuyau.py
@@ -92,8 +92,35 @@ def normalise(text, keep_punct=False, keep_sym=False):
     return out
 
 
+def max_sampling(myTexts, max_samples=10):
+    """
+    Select a random number of samples, equal to max_samples, for authors or classes that have more than max_samples
+    :param myTexts: the input myTexts object
+    :param max_samples: the maximum number of samples for any class
+    :return: a myTexts object, with the resulting selection of samples
+    """
+    autsCounts = dict()
+    for text in myTexts:
+        if text['aut'] not in autsCounts.keys():
+            autsCounts[text['aut']] = 1
+
+        else:
+            autsCounts[text['aut']] += 1
+
+    for autCount in autsCounts.items():
+        if autCount[1] > max_samples:
+            # get random selection
+            toBeSelected = [text for text in myTexts if text['aut'] == autCount[0]]
+            toBeSelected = random.sample(toBeSelected, k=max_samples)
+            # Great, now remove all texts from this author from our samples
+            myTexts = [text for text in myTexts if text['aut'] != autCount[0]]
+            # and now concat
+            myTexts = myTexts + toBeSelected
+
+    return myTexts
+
 
-def load_texts(paths, identify_lang=False, format="txt", keep_punct=False, keep_sym=False):
+def load_texts(paths, identify_lang=False, format="txt", keep_punct=False, keep_sym=False, max_samples=10):
     """
     Loads a collection of documents into a 'myTexts' object for further processing.
     TODO: a proper class
@@ -102,11 +129,11 @@ def load_texts(paths, identify_lang=False, format="txt", keep_punct=False, keep_
     :param format: format of the source files (implemented values: txt [default], xml)
     :param keep_punct: whether or not to keep punctuation and caps.
     :param keep_sym: whether or not to keep punctuation, caps, letter variants and numbers (no unidecode).
+    :param max_samples: the maximum number of samples for any class
     :return: a myTexts object
     """
     myTexts = []
-    # langCerts = []
 
 
     for path in paths:
         name = path.split('/')[-1]
@@ -127,20 +154,9 @@ def load_texts(paths, identify_lang=False, format="txt", keep_punct=False, keep_
 
         myTexts.append({"name": name, "aut": aut, "text": text, "lang": lang})
 
-        # if cert < 1:
-        #     langCerts.append((lang, name, cert))
-
-        # directory = "train_txt/" + lang + "/" + aut + "/"
-
-        # if not os.path.exists(directory):
-        #     os.makedirs(directory)
-
-        # with open(directory + name + ".txt", "w") as out:
-        #     out.write(text)
+    if max_samples is not None:
+        myTexts = max_sampling(myTexts, max_samples=max_samples)
 
-    # with open("lang_certs.csv", 'w') as out:
-    #     for line in langCerts:
-    #         out.write("{}\t{}\t{}\t\n".format(line[0], line[1], float(line[2])))
 
 
     return myTexts
@@ -219,8 +235,9 @@ def docs_to_samples(paths, size, step=None, units="verses", feature="tokens", fo
     :param feature: type of tokens to extract (default is tokens, not lemmas or POS)
     :param format: type of document, one of full text, TEI or simple XML (ONLY TEI and TXT IMPLEMENTED)
     :param keep_punct: whether or not to keep punctuation and caps.
-    :param max_samples: maximum number of samples per author.
+    :param max_samples: maximum number of samples per author/class.
     :param identify_lang: whether or not try to identify lang (default: False)
+    :return: a myTexts object
     """
     myTexts = []
     for path in paths:
@@ -246,22 +263,6 @@ def docs_to_samples(paths, size, step=None, units="verses", feature="tokens", fo
             myTexts.append({"name": name, "aut": aut, "text": text, "lang": lang})
 
     if max_samples is not None:
-        autsCounts = dict()
-        for text in myTexts:
-            if text['aut'] not in autsCounts.keys():
-                autsCounts[text['aut']] = 1
-
-            else:
-                autsCounts[text['aut']] += 1
-
-        for autCount in autsCounts.items():
-            if autCount[1] > max_samples:
-                # get random selection
-                toBeSelected = [text for text in myTexts if text['aut'] == autCount[0]]
-                toBeSelected = random.sample(toBeSelected, k=max_samples)
-                # Great, now remove all texts from this author from our samples
-                myTexts = [text for text in myTexts if text['aut'] != autCount[0]]
-                # and now concat
-                myTexts = myTexts + toBeSelected
+        myTexts = max_sampling(myTexts, max_samples=max_samples)
 
     return myTexts
diff --git a/tests/test_main.py b/tests/test_main.py
index aa3cf9bc..a10d0bc7 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -98,6 +98,8 @@ def test_get_counts(self):
 
 #TODO: a lot more tests
+# TODO: tests for SVM, etc.
+# Test all options of main commands, see if they are accepted or not
 
 
 if __name__ == '__main__':
     unittest.main()
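Not part of the patch: a minimal, illustrative sketch of the new max_sampling helper and the max_samples option added above. The myTexts entries below are made up; per the diff, passing max_samples=None to load_texts or docs_to_samples skips the capping entirely.

    from superstyl.preproc.tuyau import max_sampling

    # A toy myTexts list: 25 samples for AuthorA, 1 for AuthorB (entries are hypothetical)
    myTexts = [{"name": f"A_{i}", "aut": "AuthorA", "text": "...", "lang": "en"} for i in range(25)]
    myTexts.append({"name": "B_0", "aut": "AuthorB", "text": "...", "lang": "en"})

    # AuthorA exceeds the cap and is reduced to 10 randomly chosen samples; AuthorB is left untouched
    reduced = max_sampling(myTexts, max_samples=10)
    print(sum(1 for t in reduced if t["aut"] == "AuthorA"))  # 10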