Merge pull request #261 from ntumlgroup/update_api
Update APIs for reading data
Eleven1Liu authored Feb 19, 2023
2 parents ccb7e2c + 98c351d commit 8f499f7
Showing 4 changed files with 28 additions and 30 deletions.
libmultilabel/linear/preprocessor.py (24 changes: 15 additions & 9 deletions)

@@ -65,7 +65,7 @@ def load_data(self, training_data: Union[str, pd.DataFrame] = None,
         self.include_test_labels = include_test_labels
 
         if self.data_format == 'txt' or 'dataframe':
-            data = self._load_libmultilabel(training_data, test_data, eval)
+            data = self._load_text(training_data, test_data, eval)
         elif self.data_format == 'svm':
             data = self._load_svm(training_data, test_data, eval)
 

@@ -86,18 +86,12 @@ def load_data(self, training_data: Union[str, pd.DataFrame] = None,
 
         return data
 
-    def _load_libmultilabel(self, training_data, test_data, eval) -> 'dict[str, dict]':
+    def _load_text(self, training_data, test_data, eval) -> 'dict[str, dict]':
         datasets = defaultdict(dict)
         if test_data is not None:
-            if self.data_format == 'txt':
-                test_data = pd.read_csv(test_data, sep='\t', header=None,
-                                        error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
             test = read_libmultilabel_format(test_data)
 
         if not eval:
-            if self.data_format == 'txt':
-                training_data = pd.read_csv(training_data, sep='\t', header=None,
-                                            error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
             train = read_libmultilabel_format(training_data)
             self._generate_tfidf(train['text'])
 

@@ -145,7 +139,18 @@ def _generate_label_mapping(self, labels, classes=None):
         self.binarizer.fit(labels)
 
 
-def read_libmultilabel_format(data: pd.DataFrame) -> 'dict[str,list[str]]':
+def read_libmultilabel_format(data: Union[str, pd.DataFrame]) -> 'dict[str,list[str]]':
+    """Read multi-label text data from file or pandas dataframe.
+
+    Args:
+        data (Union[str, pd.DataFrame]): A file path to data in `LibMultiLabel format <https://www.csie.ntu.edu.tw/~cjlin/libmultilabel/cli/ov_data_format.html#libmultilabel-format>`_
+            or a pandas dataframe containing index (optional), label, and text.
+    Returns:
+        dict[str,list[str]]: A dictionary with a list of index (optional), label, and text.
+    """
+    if isinstance(data, str):
+        data = pd.read_csv(data, sep='\t', header=None,
+                           on_bad_lines='warn', quoting=csv.QUOTE_NONE).fillna('')
     data = data.astype(str)
     if data.shape[1] == 2:
         data.columns = ['label', 'text']

@@ -157,6 +162,7 @@ def read_libmultilabel_format(data: pd.DataFrame) -> 'dict[str,list[str]]':
         data['label'] = data['label'].map(lambda s: s.split())
     return data.to_dict('list')
 
+
 def read_libsvm_format(file_path: str) -> 'tuple[list[list[int]], sparse.csr_matrix]':
     """Read multi-label LIBSVM-format data.
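
With this change, read_libmultilabel_format accepts either a file path or a pandas DataFrame, so callers no longer run pd.read_csv themselves. A minimal sketch of the updated call, assuming this commit is installed; the path and DataFrame contents below are illustrative only, not from the repository:

import pandas as pd
from libmultilabel.linear.preprocessor import read_libmultilabel_format

# Option 1: a tab-separated file in LibMultiLabel format ([index]<TAB>label<TAB>text).
#     data = read_libmultilabel_format('data/train.txt')  # hypothetical path

# Option 2: an already-loaded DataFrame with label and text columns.
df = pd.DataFrame({'label': ['a b', 'c'], 'text': ['first document', 'second document']})
data = read_libmultilabel_format(df)
# Labels are split on whitespace: data['label'] == [['a', 'b'], ['c']]
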
libmultilabel/nn/data_utils.py (28 changes: 10 additions & 18 deletions)

@@ -135,6 +135,10 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
     Returns:
         pandas.DataFrame: Data composed of index, label, and tokenized text.
     """
+    if isinstance(data, str):
+        logging.info(f'Load data from {data}.')
+        data = pd.read_csv(data, sep='\t', header=None,
+                           on_bad_lines='warn', quoting=csv.QUOTE_NONE).fillna('')
     data = data.astype(str)
     if data.shape[1] == 2:
         data.columns = ['label', 'text']

@@ -197,31 +201,19 @@ def load_datasets(
 
     datasets = {}
     if training_data is not None:
-        if isinstance(training_data, str):
-            logging.info(f'Load data from {training_data}.')
-            training_data = pd.read_csv(training_data, sep='\t', header=None,
-                                        error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
-        datasets['train'] = _load_raw_data(training_data, tokenize_text=tokenize_text,
-                                           remove_no_label_data=remove_no_label_data)
+        datasets['train'] = _load_raw_data(
+            training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data)
 
     if val_data is not None:
-        if isinstance(val_data, str):
-            logging.info(f'Load data from {val_data}.')
-            val_data = pd.read_csv(val_data, sep='\t', header=None,
-                                   error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
-        datasets['val'] = _load_raw_data(val_data, tokenize_text=tokenize_text,
-                                         remove_no_label_data=remove_no_label_data)
+        datasets['val'] = _load_raw_data(
+            val_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data)
     elif val_size > 0:
         datasets['train'], datasets['val'] = train_test_split(
             datasets['train'], test_size=val_size, random_state=42)
 
     if test_data is not None:
-        if isinstance(test_data, str):
-            logging.info(f'Load data from {test_data}.')
-            test_data = pd.read_csv(test_data, sep='\t', header=None,
-                                    error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
-        datasets['test'] = _load_raw_data(test_data, is_test=True, tokenize_text=tokenize_text,
-                                          remove_no_label_data=remove_no_label_data)
+        datasets['test'] = _load_raw_data(
+            test_data, is_test=True, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data)
 
     if merge_train_val:
         datasets['train'] = datasets['train'] + datasets['val']
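
Because _load_raw_data now performs the isinstance(data, str) check and the read_csv call itself, the per-split branches in load_datasets collapse into single _load_raw_data calls. A sketch of how the updated loader can be used, assuming the keyword names seen in the function body above; the paths and DataFrame are hypothetical:

import pandas as pd
from libmultilabel.nn.data_utils import load_datasets

# With file paths in LibMultiLabel format (hypothetical paths):
#     datasets = load_datasets(training_data='data/train.txt',
#                              test_data='data/test.txt', val_size=0.2)

# Or with a DataFrame; _load_raw_data only calls read_csv for string inputs.
train_df = pd.DataFrame({'label': ['a b', 'c d'], 'text': ['doc one', 'doc two']})
datasets = load_datasets(training_data=train_df, val_size=0.5)
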
requirements.txt (4 changes: 2 additions & 2 deletions)

@@ -1,8 +1,8 @@
 nltk
-pandas
+pandas>1.3.0
 PyYAML
 scikit-learn
-torch>=1.12.0
+torch>=1.13.1
 torchmetrics==0.10.3
 torchtext>=0.13.0
 pytorch-lightning==1.7.7
setup.cfg (2 changes: 1 addition & 1 deletion)

@@ -26,7 +26,7 @@ classifiers =
 packages = find:
 install_requires =
     nltk
-    pandas
+    pandas>1.3.0
     PyYAML
     scikit-learn
    torch>=1.13.1
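
The pandas floor in requirements.txt and setup.cfg moves to >1.3.0 because the readers above replace the error_bad_lines/warn_bad_lines flags, which pandas deprecated in 1.3 and later removed, with the on_bad_lines argument introduced in pandas 1.3.0. A minimal before/after sketch; the inline sample data is illustrative:

import csv
import io
import pandas as pd

# Old call, deprecated since pandas 1.3:
#     pd.read_csv(path, sep='\t', header=None,
#                 error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE)

# New call used throughout this commit; pandas>1.3.0 guarantees on_bad_lines exists.
sample = io.StringIO('label_a label_b\tsome document text\n')
df = pd.read_csv(sample, sep='\t', header=None,
                 on_bad_lines='warn', quoting=csv.QUOTE_NONE).fillna('')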
