Skip to content

Commit

Permalink
fixing and extending max samples
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-Baptiste-Camps committed Feb 16, 2024
1 parent e56198a commit e89c602
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 44 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mySVM.joblib
.ipynb_checkpoints/*
*.ipynb_checkpoints*
data
models
*.json
*.log
*.txt
Expand Down
13 changes: 2 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

# TODO: eliminate features that occur only n times ?
# Do the Moisl Selection ?
# TODO: free up memory as the script goes by deleting unnecessary objects

if __name__ == '__main__':

Expand Down Expand Up @@ -62,7 +61,8 @@
keep_punct=args.keep_punct, keep_sym=args.keep_sym, max_samples=args.max_samples)

else:
myTexts = tuy.load_texts(args.s, identify_lang=args.identify_lang, format=args.x, keep_punct=args.keep_punct, keep_sym=args.keep_sym)
myTexts = tuy.load_texts(args.s, identify_lang=args.identify_lang, format=args.x, keep_punct=args.keep_punct,
keep_sym=args.keep_sym, max_samples=args.max_samples)

print(".......getting features.......")

Expand Down Expand Up @@ -96,15 +96,6 @@

print(".......feeding data frame.......")

#feats = pandas.DataFrame(columns=list(feat_list), index=unique_texts)


# with Pool(args.p) as pool:
# print(args.p)
# target = zip(myTexts, [feat_list] * len(myTexts))
# with tqdm.tqdm(total=len(myTexts)) as pbar:
# for text, local_freqs in pool.map(count_process, target):

loc = {}

for t in tqdm.tqdm(myTexts):
Expand Down
67 changes: 34 additions & 33 deletions superstyl/preproc/tuyau.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,35 @@ def normalise(text, keep_punct=False, keep_sym=False):

return out

def max_sampling(myTexts, max_samples=10):
"""
Select a random number of samples, equal to max_samples, for authors or classes that have more than max_samples
:param myTexts: the input myTexts object
:param max_samples: the maximum number of samples for any class
:return: a myTexts object, with the resulting selection of samples
"""
autsCounts = dict()
for text in myTexts:
if text['aut'] not in autsCounts.keys():
autsCounts[text['aut']] = 1

Check warning on line 105 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L102-L105

Added lines #L102 - L105 were not covered by tests

else:
autsCounts[text['aut']] += 1

Check warning on line 108 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L108

Added line #L108 was not covered by tests

for autCount in autsCounts.items():
if autCount[1] > max_samples:

Check warning on line 111 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L110-L111

Added lines #L110 - L111 were not covered by tests
# get random selection
toBeSelected = [text for text in myTexts if text['aut'] == autCount[0]]
toBeSelected = random.sample(toBeSelected, k=max_samples)

Check warning on line 114 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L113-L114

Added lines #L113 - L114 were not covered by tests
# Great, now remove all texts from this author from our samples
myTexts = [text for text in myTexts if text['aut'] != autCount[0]]

Check warning on line 116 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L116

Added line #L116 was not covered by tests
# and now concat
myTexts = myTexts + toBeSelected

Check warning on line 118 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L118

Added line #L118 was not covered by tests

return myTexts

Check warning on line 120 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L120

Added line #L120 was not covered by tests


def load_texts(paths, identify_lang=False, format="txt", keep_punct=False, keep_sym=False):
def load_texts(paths, identify_lang=False, format="txt", keep_punct=False, keep_sym=False, max_samples=10):
"""
Loads a collection of documents into a 'myTexts' object for further processing.
TODO: a proper class
Expand All @@ -102,11 +129,11 @@ def load_texts(paths, identify_lang=False, format="txt", keep_punct=False, keep_
:param format: format of the source files (implemented values: txt [default], xml)
:param keep_punct: whether or not to keep punctuation and caps.
:param keep_sym: whether or not to keep punctuation, caps, letter variants and numbers (no unidecode).
:param max_samples: the maximum number of samples for any class
:return: a myTexts object
"""

myTexts = []
# langCerts = []

for path in paths:
name = path.split('/')[-1]
Expand All @@ -127,20 +154,9 @@ def load_texts(paths, identify_lang=False, format="txt", keep_punct=False, keep_

myTexts.append({"name": name, "aut": aut, "text": text, "lang": lang})

# if cert < 1:
# langCerts.append((lang, name, cert))

# directory = "train_txt/" + lang + "/" + aut + "/"

# if not os.path.exists(directory):
# os.makedirs(directory)

# with open(directory + name + ".txt", "w") as out:
# out.write(text)
if max_samples is not None:
myTexts = max_sampling(myTexts, max_samples=max_samples)

Check warning on line 158 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L157-L158

Added lines #L157 - L158 were not covered by tests

# with open("lang_certs.csv", 'w') as out:
# for line in langCerts:
# out.write("{}\t{}\t{}\t\n".format(line[0], line[1], float(line[2])))
return myTexts


Expand Down Expand Up @@ -219,8 +235,9 @@ def docs_to_samples(paths, size, step=None, units="verses", feature="tokens", fo
:param feature: type of tokens to extract (default is tokens, not lemmas or POS)
:param format: type of document, one of full text, TEI or simple XML (ONLY TEI and TXT IMPLEMENTED)
:param keep_punct: whether or not to keep punctuation and caps.
:param max_samples: maximum number of samples per author.
:param max_samples: maximum number of samples per author/class.
:param identify_lang: whether or not try to identify lang (default: False)
:return: a myTexts object
"""
myTexts = []
for path in paths:
Expand All @@ -246,22 +263,6 @@ def docs_to_samples(paths, size, step=None, units="verses", feature="tokens", fo
myTexts.append({"name": name, "aut": aut, "text": text, "lang": lang})

if max_samples is not None:
autsCounts = dict()
for text in myTexts:
if text['aut'] not in autsCounts.keys():
autsCounts[text['aut']] = 1

else:
autsCounts[text['aut']] += 1

for autCount in autsCounts.items():
if autCount[1] > max_samples:
# get random selection
toBeSelected = [text for text in myTexts if text['aut'] == autCount[0]]
toBeSelected = random.sample(toBeSelected, k=max_samples)
# Great, now remove all texts from this author from our samples
myTexts = [text for text in myTexts if text['aut'] != autCount[0]]
# and now concat
myTexts = myTexts + toBeSelected
myTexts = max_sampling(myTexts, max_samples=max_samples)

Check warning on line 266 in superstyl/preproc/tuyau.py

View check run for this annotation

Codecov / codecov/patch

superstyl/preproc/tuyau.py#L266

Added line #L266 was not covered by tests

return myTexts
2 changes: 2 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def test_get_counts(self):
#TODO: a lot more tests


# TODO: tests for SVM, etc.
# Test all options of main commands, see if they are accepted or not

if __name__ == '__main__':
unittest.main()

0 comments on commit e89c602

Please sign in to comment.