Skip to content

Commit

Permalink
fix scikit learn warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-Baptiste-Camps committed Feb 16, 2024
1 parent 30c7a9c commit eb83762
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,17 +41,17 @@
parser.add_argument('--identify_lang', action='store_true',
help="if true, should the language of each text be guessed, using langdetect (default is False)",
default=False)
-parser.add_argument('--embedding', action="store", help="optional path to a Glove embedding to compute frequencies among a set of semantic neighbourgs (i.e., pseudo-paronyms)",
+parser.add_argument('--embedding', action="store", help="optional path to a word2vec embedding in txt format to compute frequencies among a set of semantic neighbourgs (i.e., pseudo-paronyms)",
default=False)
parser.add_argument('--neighbouring_size', action="store", help="size of semantic neighbouring in the embedding (n closest neighbours)",
-default=10)
+default=10, type=int)
args = parser.parse_args()

embeddedFreqs = False
if args.embedding:
print(".......loading embedding.......")
args.absolute_freqs = True # we need absolute freqs as a basis for embedded frequencies
-embeddings_dict = embed.load_glove_embeddings(args.embedding)
+model = embed.load_embeddings(args.embedding)
embeddedFreqs = True

print(".......loading texts.......")
Expand Down Expand Up @@ -90,7 +90,7 @@

if args.embedding:
print(".......embedding counts.......")
-myTexts = embed.get_embedded_counts(myTexts, feat_list, embeddings_dict, topn=args.neighbouring_size)
+myTexts = embed.get_embedded_counts(myTexts, feat_list, model, topn=args.neighbouring_size)

unique_texts = [text["name"] for text in myTexts]

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,7 +4,7 @@ lxml>=4.9.1
nltk>=3.6.6
numpy>=1.26.4
pybind11>=2.8.1
-scikit-learn>=1.2.1
+scikit-learn>=1.3.0
scipy>=1.10.0
six>=1.16.0
tqdm>=4.64.1
Expand Down
2 changes: 1 addition & 1 deletion superstyl/svm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -110,7 +110,7 @@ def train_svm(train, test, cross_validate=None, k=10, dim_reduc=None, norms=True

if kernel == "LinearSVC":
# try a faster one
-estimators.append(('model', sk.LinearSVC(class_weight=cw)))
+estimators.append(('model', sk.LinearSVC(class_weight=cw, dual="auto")))
# classif = sk.LinearSVC()

else:
Expand Down

0 comments on commit eb83762

Please sign in to comment.