- 2020-07-29:
- ktrain v0.19.x is released and now includes support for "traditional" tabular data and explainable AI for tabular predictions. See the tutorial notebook on tabular models for both:
- a classification example (using Kaggle's Titanic passenger survival prediction dataset)
- a regression example (using UCI's Adult census dataset for age prediction)
- 2020-07-07:
- ktrain v0.18.x is released and now includes support for TensorFlow 2.2.0. Due to various TensorFlow 2.2.0 bugs, TF 2.2.0 is only installed if Python 3.8 is being used. Otherwise, TensorFlow 2.1.0 is always installed (i.e., on Python 3.6 and 3.7 systems).
- 2020-06-28:
- Hamiz Ahmed published his Medium article: Finetuning BERT using ktrain for Disaster Tweets Classification
- 2020-06-26:
- ktrain v0.17.x is released and includes support for language translation. See the example language translation notebook for more information. (This feature currently requires that PyTorch be installed.)
```python
# Example: Translating Chinese to German
# NOTE: Language Translation uses PyTorch instead of TensorFlow
from ktrain import text
translator = text.Translator(model_name='Helsinki-NLP/opus-mt-zh-de')
src_text = '''大流行对世界经济造成了严重破坏。但是,截至2020年6月,美国股票市场持续上涨。'''
print(translator.translate(src_text))
# output:
# Die Pandemie hat eine ernste Zerstörung der Weltwirtschaft verursacht.
# Aber bis Juni 2020 stieg der US-Markt weiter an.
```
- 2020-06-03:
- ktrain v0.16.x is released and includes support for Zero-Shot Learning, where documents can be classified into user-provided topics without any training examples. See the example notebook. (This feature currently requires that PyTorch be installed.)
```python
# Zero-Shot Topic Classification in ktrain (NOTE: Zero-Shot Learning uses PyTorch instead of TensorFlow)
from ktrain import text
zsl = text.ZeroShotClassifier()
topic_strings = ['politics', 'elections', 'sports', 'films', 'television']
doc = 'I am unhappy with decisions of the government and will definitely vote in 2020.'
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)
# output:
# [('politics', 0.9829113483428955),
#  ('elections', 0.9880988001823425),
#  ('sports', 0.00030677582253701985),
#  ('films', 0.0008969294722191989),
#  ('television', 0.00045271270209923387)]
```
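The scores above do not form a probability distribution over the topics: `politics` and `elections` both score above 0.98, so each topic appears to be scored independently. A minimal sketch of turning such a list of `(label, score)` tuples into assigned topics, in plain Python with no ktrain call involved; the helper names and the 0.5 cutoff are hypothetical choices for illustration:

```python
# Output of zsl.predict(..., include_labels=True), copied from the example above.
predictions = [('politics', 0.9829113483428955),
               ('elections', 0.9880988001823425),
               ('sports', 0.00030677582253701985),
               ('films', 0.0008969294722191989),
               ('television', 0.00045271270209923387)]

def assign_topics(preds, threshold=0.5):
    """Hypothetical helper: keep every topic whose score meets the threshold,
    ordered from highest to lowest score."""
    kept = [(label, score) for label, score in preds if score >= threshold]
    return [label for label, _ in sorted(kept, key=lambda p: p[1], reverse=True)]

def best_topic(preds):
    """Hypothetical helper: return the single highest-scoring topic."""
    return max(preds, key=lambda p: p[1])[0]

print(assign_topics(predictions))  # ['elections', 'politics']
print(best_topic(predictions))     # elections
```

A threshold suits multi-label use (this document is plausibly about both politics and elections), while taking the single best topic suits cases where exactly one label is wanted.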
ktrain is a lightweight wrapper for the deep learning library TensorFlow Keras (and other libraries) to help build, train, and deploy neural networks and other machine learning models. Inspired by ML framework extensions like fastai and ludwig, it is designed to make deep learning and AI more accessible and easier to apply for both newcomers and experienced practitioners. With only a few lines of code, ktrain allows you to easily and quickly:
- employ fast, accurate, and easy-to-use pre-canned models for `text`, `vision`, `graph`, and `tabular` data:
  - `text` data:
    - Text Classification: BERT, DistilBERT, NBSVM, fastText, and other models [example notebook]
    - Text Regression: BERT, DistilBERT, embedding-based linear text regression, fastText, and other models [example notebook]
    - Sequence Labeling (NER): Bidirectional LSTM with optional CRF layer and various embedding schemes such as pretrained BERT and fastText word embeddings and character embeddings [example notebook]
    - Ready-to-Use NER models for English, Chinese, and Russian with no training required [example notebook]
    - Sentence Pair Classification for tasks like paraphrase detection [example notebook]
    - Unsupervised Topic Modeling with LDA [example notebook]
    - Document Similarity with One-Class Learning: given some documents of interest, find and score new documents that are semantically similar to them using One-Class Text Classification [example notebook]
    - Document Recommendation Engine: given text from a sample document, recommend documents that are thematically related to it from a larger corpus [example notebook]
    - Text Summarization: summarize long documents with a pretrained BART model (no training required) [example notebook]
    - Open-Domain Question-Answering: ask a large text corpus questions and receive exact answers [example notebook]
    - Zero-Shot Learning: classify documents into user-provided topics without training examples [example notebook]
    - Language Translation: translate text from one language to another [example notebook]
  - `vision` data:
    - image classification (e.g., ResNet, Wide ResNet, Inception) [example notebook]
    - image regression for predicting numerical targets from photos (e.g., age prediction) [example notebook]
  - `graph` data:
    - node classification with graph neural networks (GraphSAGE) [example notebook]
    - link prediction with graph neural networks (GraphSAGE) [example notebook]
  - `tabular` data:
    - tabular classification using the Titanic dataset [example notebook]
    - tabular regression using Census data [example notebook]
- estimate an optimal learning rate for your model given your data using a Learning Rate Finder
- utilize learning rate schedules such as the triangular policy, the 1cycle policy, and SGDR to effectively minimize loss and improve generalization
- build text classifiers for any language (e.g., Chinese Sentiment Analysis with BERT, Arabic Sentiment Analysis with NBSVM)
- easily train NER models for any language (e.g., Dutch NER)
- load and preprocess text and image data from a variety of formats
- inspect data points that were misclassified and provide explanations to help improve your model
- leverage a simple prediction API for saving and deploying both models and data-preprocessing steps to make predictions on new raw data
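To illustrate what a cyclical schedule does, the triangular policy named in the list above (from Leslie Smith's work on cyclical learning rates) moves the learning rate linearly from a lower bound up to an upper bound and back down over each cycle. The following is a minimal, framework-free sketch of that formula; the names `base_lr`, `max_lr`, and `step_size` are illustrative and are not claimed to be ktrain parameters:

```python
import math

def triangular_lr(iteration, base_lr, max_lr, step_size):
    """Triangular cyclical learning rate.

    The rate climbs linearly from base_lr to max_lr over `step_size`
    iterations, falls back to base_lr over the next `step_size`
    iterations, and then the cycle repeats.
    """
    cycle = math.floor(1 + iteration / (2 * step_size))   # which cycle we are in (1-based)
    x = abs(iteration / step_size - 2 * cycle + 1)         # 1 at cycle edges, 0 at the peak
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)

# One full cycle of 8 iterations (step_size=4), sampled every 2 iterations.
for it in range(0, 9, 2):
    print(it, triangular_lr(it, base_lr=1e-5, max_lr=1e-4, step_size=4))
```

The 1cycle policy in the same list is, roughly, a single such rise-and-fall stretched over the whole training run rather than repeated.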
Please see the following tutorial notebooks for a guide on how to use ktrain in your projects:
- Tutorial 1: Introduction
- Tutorial 2: Tuning Learning Rates
- Tutorial 3: Image Classification
- Tutorial 4: Text Classification
- Tutorial 5: Learning from Unlabeled Text Data
- Tutorial 6: Text Sequence Tagging for Named Entity Recognition
- Tutorial 7: Graph Node Classification with Graph Neural Networks
- Tutorial 8: Tabular Classification and Regression
- Tutorial A1: Additional tricks, which covers topics such as previewing data augmentation schemes, inspecting intermediate output of Keras models for debugging, setting global weight decay, and use of built-in and custom callbacks.
- Tutorial A2: Explaining Predictions and Misclassifications
- Tutorial A3: Text Classification with Hugging Face Transformers
- Tutorial A4: Using Custom Data Formats and Models: Text Regression with Extra Regressors
Some blog tutorials about ktrain are shown below:
- ktrain: A Lightweight Wrapper for Keras to Help Train Neural Networks
- Text Classification with Hugging Face Transformers in TensorFlow 2 (Without Tears)
- Build an Open-Domain Question-Answering System With BERT in 3 Lines of Code
- Finetuning BERT using ktrain for Disaster Tweets Classification by Hamiz Ahmed
Tasks such as text classification and image classification can be accomplished easily with only a few lines of code.
Example: Text Classification of IMDb Movie Reviews Using BERT
```python
import ktrain
from ktrain import text as txt
# load data
(x_train, y_train), (x_test, y_test), preproc = txt.texts_from_folder('data/aclImdb', maxlen=500,
                                                                     preprocess_mode='bert',
                                                                     train_test_names=['train', 'test'],
                                                                     classes=['pos', 'neg'])
# load model
model = txt.text_classifier('bert', (x_train, y_train), preproc=preproc)
# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model,
                             train_data=(x_train, y_train),
                             val_data=(x_test, y_test),
                             batch_size=6)
# find good learning rate
learner.lr_find()  # briefly simulate training to find good learning rate
learner.lr_plot()  # visually identify best learning rate
# train using 1cycle learning rate schedule for 3 epochs
learner.fit_onecycle(2e-5, 3)
```
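The `lr_find` call above is described as briefly simulating training to find a good learning rate. The general idea behind such a learning-rate range test is to raise the learning rate exponentially, one step at a time, while recording the loss, and to stop once the loss blows up. The sketch below is a toy illustration of that idea on a one-parameter quadratic loss, not ktrain's implementation; the helper name, the loss function, the divergence cutoff, and the "divide the best rate by 10" rule of thumb are all assumptions chosen for illustration:

```python
def lr_range_test(start_lr=1e-4, end_lr=10.0, num_steps=50, diverge_factor=4.0):
    """Toy learning-rate range test on the loss f(w) = (w - 3)**2.

    One SGD step is taken per learning rate while the rate grows
    geometrically from start_lr to end_lr; the test stops early once
    the loss exceeds diverge_factor times the best loss seen so far.
    """
    w = 0.0
    lrs, losses = [], []
    best_loss = float('inf')
    for i in range(num_steps):
        lr = start_lr * (end_lr / start_lr) ** (i / (num_steps - 1))
        grad = 2 * (w - 3.0)      # derivative of (w - 3)**2
        w -= lr * grad            # plain SGD update
        loss = (w - 3.0) ** 2
        lrs.append(lr)
        losses.append(loss)
        best_loss = min(best_loss, loss)
        if loss > diverge_factor * best_loss:
            break                 # loss has blown up: stop the simulation
    return lrs, losses

lrs, losses = lr_range_test()
best = min(range(len(losses)), key=losses.__getitem__)
suggested_lr = lrs[best] / 10     # rule of thumb: back off from where the loss bottoms out
print(f"stopped after {len(lrs)} steps; suggested lr = {suggested_lr:.4f}")
```

Reading `lr_plot` follows the same logic visually: pick a rate where the loss is still falling clearly, somewhat below the point where it starts to rise.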
Example: Classifying Images of Dogs and Cats Using a Pretrained ResNet50 model
```python
import ktrain
from ktrain import vision as vis
# load data
(train_data, val_data, preproc) = vis.images_from_folder(
    datadir='data/dogscats',
    data_aug=vis.get_data_aug(horizontal_flip=True),
    train_test_names=['train', 'valid'],
    target_size=(224, 224), color_mode='rgb')
# load model
model = vis.image_classifier('pretrained_resnet50', train_data, val_data, freeze_layers=80)
# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model=model, train_data=train_data, val_data=val_data,
                             workers=8, use_multiprocessing=False, batch_size=64)
# find good learning rate
learner.lr_find()  # briefly simulate training to find good learning rate
learner.lr_plot()  # visually identify best learning rate
# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(1e-4, checkpoint_folder='/tmp/saved_weights')
```
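The comment above says `autofit` trains with implicit ReduceLROnPlateau and EarlyStopping. The sketch below shows, in plain Python, a simplified version of the plateau logic those two Keras callbacks implement: halve the learning rate when validation loss stops improving for a while, and stop when it has not improved for longer. The function name, patience values, and reduction factor are illustrative assumptions, not ktrain's defaults:

```python
def plateau_schedule(val_losses, lr, reduce_patience=2, stop_patience=4, factor=0.5):
    """Simplified ReduceLROnPlateau + EarlyStopping over per-epoch validation losses.

    Returns (learning rate used at each epoch, index of the epoch at which
    training stopped early, or None if every epoch ran).
    """
    best = float('inf')
    since_best = 0      # epochs since val loss last improved (drives early stopping)
    since_reduce = 0    # epochs since last improvement or LR reduction (drives LR decay)
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)
        if loss < best:
            best, since_best, since_reduce = loss, 0, 0
            continue
        since_best += 1
        since_reduce += 1
        if since_best >= stop_patience:
            return lrs, epoch           # early stopping
        if since_reduce >= reduce_patience:
            lr *= factor                # reduce learning rate on plateau
            since_reduce = 0
    return lrs, None

# Hypothetical validation losses: improvement, a plateau, a brief recovery, then a stall.
val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.59, 0.60, 0.60, 0.61, 0.62, 0.63]
lrs, stopped = plateau_schedule(val_losses, lr=1e-4)
print(lrs, stopped)
```

The real Keras callbacks add options this sketch leaves out, such as `min_delta`, `cooldown`, and `min_lr` for ReduceLROnPlateau and `restore_best_weights` for EarlyStopping.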
Example: Sequence Labeling for Named Entity Recognition using a randomly initialized Bidirectional LSTM CRF model
```python
import ktrain
from ktrain import text as txt
# load data
(trn, val, preproc) = txt.entities_from_txt('data/ner_dataset.csv',
                                            sentence_column='Sentence #',
                                            word_column='Word',
                                            tag_column='Tag',
                                            data_format='gmb',
                                            use_char=True)  # enable character embeddings
# load model
model = txt.sequence_tagger('bilstm-crf', preproc)
# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, train_data=trn, val_data=val)
# conventional training for 1 epoch using a learning rate of 0.001 (Keras default for Adam optimizer)
learner.fit(1e-3, 1)
```
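`entities_from_txt` reads token-per-row data, with the sentence, word, and tag column names passed in explicitly. Assuming the layout of the widely used GMB-derived `ner_dataset.csv` (one token per row, with the `Sentence #` column filled only on the first token of each sentence), the sketch below shows how such rows group into the parallel token and tag lists of the kind `entities_from_array` accepts in the BioBERT example further down. The sample rows and the `group_sentences` helper are hypothetical; `entities_from_txt` presumably performs its own parsing:

```python
import csv
import io

# Hypothetical miniature of a GMB-style CSV: one token per row, and the
# sentence column is only filled on the first token of each sentence.
SAMPLE = """Sentence #,Word,Tag
Sentence: 1,Thousands,O
,marched,O
,in,O
,London,B-geo
Sentence: 2,Iran,B-geo
,said,O
"""

def group_sentences(f, sentence_column='Sentence #', word_column='Word', tag_column='Tag'):
    """Group token rows into parallel lists of words and tags, one pair per sentence."""
    sentences, tags = [], []
    for row in csv.DictReader(f):
        if row[sentence_column].strip():   # a non-empty value starts a new sentence
            sentences.append([])
            tags.append([])
        sentences[-1].append(row[word_column])
        tags[-1].append(row[tag_column])
    return sentences, tags

x, y = group_sentences(io.StringIO(SAMPLE))
print(x)  # [['Thousands', 'marched', 'in', 'London'], ['Iran', 'said']]
print(y)  # [['O', 'O', 'O', 'B-geo'], ['B-geo', 'O']]
```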
Example: Node Classification on Cora Citation Graph using a GraphSAGE model
```python
import ktrain
from ktrain import graph as gr
# load data with supervision ratio of 10%
(trn, val, preproc) = gr.graph_nodes_from_csv(
    'cora.content',  # node attributes/labels
    'cora.cites',    # edge list
    sample_size=20,
    holdout_pct=None,
    holdout_for_inductive=False,
    train_pct=0.1, sep='\t')
# load model
model = gr.graph_node_classifier('graphsage', trn)
# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=64)
# find good learning rate
learner.lr_find(max_epochs=100)  # briefly simulate training to find good learning rate
learner.lr_plot()                # visually identify best learning rate
# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(0.01, checkpoint_folder='/tmp/saved_weights')
```
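GraphSAGE builds each node's representation from a fixed-size sample of its neighbors, which is what the `sample_size` argument above appears to control (our reading; check the ktrain docstring). The sketch below shows one simplified mean-aggregation step in plain Python: neighbors are sampled with replacement so every node gets exactly `k` samples, their feature vectors are averaged, and the average is concatenated with the node's own features. The learned weight matrix and nonlinearity of a real GraphSAGE layer are deliberately omitted, and the toy graph and helper names are hypothetical:

```python
import random

def sample_neighbors(adj, node, k, rng):
    """Sample exactly k neighbors of `node`, with replacement."""
    neighbors = adj[node]
    return [rng.choice(neighbors) for _ in range(k)] if neighbors else []

def mean_aggregate(features, adj, node, k, rng):
    """One simplified GraphSAGE mean-aggregation step (no weights, no nonlinearity):
    concatenate the node's own features with the mean of its sampled neighbors'."""
    sampled = sample_neighbors(adj, node, k, rng)
    dim = len(features[node])
    if sampled:
        mean = [sum(features[n][d] for n in sampled) / len(sampled) for d in range(dim)]
    else:
        mean = [0.0] * dim   # isolated node: no neighborhood information
    return features[node] + mean

# Hypothetical toy citation graph (undirected adjacency lists) with 2-D features.
adj = {'a': ['b', 'c'], 'b': ['a'], 'c': ['a'], 'd': []}
features = {'a': [1.0, 0.0], 'b': [0.0, 1.0], 'c': [0.0, 3.0], 'd': [5.0, 5.0]}

rng = random.Random(0)
print(mean_aggregate(features, adj, 'b', k=20, rng=rng))  # [0.0, 1.0, 1.0, 0.0]
```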
Example: Text Classification with Hugging Face Transformers on 20 Newsgroups Dataset Using DistilBERT
```python
# load text data
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True)
test_b = fetch_20newsgroups(subset='test', categories=categories, shuffle=True)
(x_train, y_train) = (train_b.data, train_b.target)
(x_test, y_test) = (test_b.data, test_b.target)
# build, train, and validate model (Transformer is a wrapper around the Hugging Face transformers library)
import ktrain
from ktrain import text
MODEL_NAME = 'distilbert-base-uncased'
t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)
trn = t.preprocess_train(x_train, y_train)
val = t.preprocess_test(x_test, y_test)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)
learner.fit_onecycle(5e-5, 4)
learner.validate(class_names=t.get_classes())  # class_names must be string values
# Output from learner.validate()
#                         precision    recall  f1-score   support
#
#            alt.atheism       0.92      0.93      0.93       319
#          comp.graphics       0.97      0.97      0.97       389
#                sci.med       0.97      0.95      0.96       396
# soc.religion.christian       0.96      0.96      0.96       398
#
#               accuracy                           0.96      1502
#              macro avg       0.95      0.96      0.95      1502
#           weighted avg       0.96      0.96      0.96      1502
```
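Each class row of the report above gives precision, recall, F1 score (the harmonic mean of precision and recall), and support (the number of true examples of that class); `accuracy` is the overall fraction of correct predictions, `macro avg` is the unweighted mean over classes, and `weighted avg` weights each class by its support. The layout matches scikit-learn's `classification_report`. A small, dependency-free sketch of these standard definitions follows; the helper names and the four example labels are hypothetical:

```python
def per_class_report(y_true, y_pred):
    """Precision, recall, F1, and support for each class (standard definitions)."""
    report = {}
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = (precision, recall, f1, tp + fn)   # support = number of true examples
    return report

def macro_and_weighted(report, key=2):
    """Macro (unweighted) and support-weighted average of one metric column
    (key: 0 = precision, 1 = recall, 2 = F1)."""
    values = [r[key] for r in report.values()]
    supports = [r[3] for r in report.values()]
    macro = sum(values) / len(values)
    weighted = sum(v * s for v, s in zip(values, supports)) / sum(supports)
    return macro, weighted

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Hypothetical predictions for two of the newsgroup classes.
y_true = ['sci.med', 'sci.med', 'sci.med', 'comp.graphics']
y_pred = ['sci.med', 'sci.med', 'comp.graphics', 'comp.graphics']
rep = per_class_report(y_true, y_pred)
print(rep)
print(macro_and_weighted(rep), accuracy(y_true, y_pred))
```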
Example: NER With BioBERT Embeddings
```python
# NER with BioBERT embeddings
import ktrain
from ktrain import text as txt
x_train = [['IL-2', 'responsiveness', 'requires', 'three', 'distinct', 'elements', 'within', 'the', 'enhancer', '.'], ...]
y_train = [['B-protein', 'O', 'O', 'O', 'O', 'B-DNA', 'O', 'O', 'B-DNA', 'O'], ...]
(trn, val, preproc) = txt.entities_from_array(x_train, y_train)
model = txt.sequence_tagger('bilstm-bert', preproc, bert_model='monologg/biobert_v1.1_pubmed')
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=128)
learner.fit(0.01, 1, cycle_len=5)
```
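The `y_train` labels above follow the BIO (IOB) convention: `B-` marks the first token of an entity, `I-` a continuation token, and `O` a token outside any entity. A minimal sketch of decoding such tags into entity spans, which can help when inspecting training data or predictions; `bio_to_entities` is a hypothetical helper, not part of ktrain:

```python
def bio_to_entities(tokens, tags):
    """Decode parallel token/BIO-tag lists into (entity_text, entity_type) pairs."""
    entities, current, current_type = [], [], None
    for token, tag in zip(tokens, tags):
        if tag.startswith('B-') or (tag.startswith('I-') and tag[2:] != current_type):
            # a new entity starts (an I- tag that does not continue the open entity
            # is treated as a start)
            if current:
                entities.append((' '.join(current), current_type))
            current, current_type = [token], tag[2:]
        elif tag.startswith('I-'):
            current.append(token)       # continue the open entity
        else:                           # 'O' closes any open entity
            if current:
                entities.append((' '.join(current), current_type))
            current, current_type = [], None
    if current:                         # flush an entity that ends the sentence
        entities.append((' '.join(current), current_type))
    return entities

tokens = ['IL-2', 'responsiveness', 'requires', 'three', 'distinct', 'elements',
          'within', 'the', 'enhancer', '.']
tags = ['B-protein', 'O', 'O', 'O', 'O', 'B-DNA', 'O', 'O', 'B-DNA', 'O']
print(bio_to_entities(tokens, tags))
# [('IL-2', 'protein'), ('elements', 'DNA'), ('enhancer', 'DNA')]
```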
Using ktrain on Google Colab? See these Colab examples:
- a simple demo of Multiclass Text Classification with BERT
- a simple demo of Multiclass Text Classification with Hugging Face Transformers
- image classification with Cats vs. Dogs
Additional examples can be found here.
- Make sure pip is up-to-date with: `pip3 install -U pip`
- Install ktrain: `pip3 install ktrain`
Some important things to note about installation:
- If using ktrain on a local machine with a GPU (versus Google Colab, for example), you'll need to install GPU support for TensorFlow 2.
- ktrain currently uses TensorFlow 2.1.0 or 2.2.0, which will be installed automatically when installing ktrain. TensorFlow 2.1.0 will be installed as a dependency on Python 3.6 and 3.7 systems. TensorFlow 2.2.0 will be installed only if using Python 3.8 (as TF 2.1.0 does not support Python 3.8). On systems where Python 3.8 is the default (e.g., Ubuntu 20.04), we recommend installing and using Python 3.6/3.7 and TensorFlow 2.1.0 with ktrain due to unresolved bugs in versions of TensorFlow >= 2.2.0.
- Since some ktrain dependencies have not yet been migrated to `tf.keras` in TensorFlow 2 (or may have other issues), ktrain is temporarily using forked versions of some libraries. Specifically, ktrain uses forked versions of the `eli5` and `stellargraph` libraries. If not installed, ktrain will complain when a method or function needing either of these libraries is invoked. To install these forked versions, you can do the following:
  ```shell
  pip3 install git+https://github.com/amaiya/eli5@tfkeras_0_10_1
  pip3 install git+https://github.com/amaiya/stellargraph@no_tf_dep_082
  ```
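The TensorFlow version rule described in these installation notes (TensorFlow 2.2.0 on Python 3.8, TensorFlow 2.1.0 on Python 3.6 and 3.7) can be summarized as a small function. This is only an illustration of the rule as stated here, written as a hypothetical helper that is not part of ktrain or pip; it does not inspect what is actually installed:

```python
import sys

def expected_tensorflow_version(python_version=None):
    """Return the TensorFlow version these notes say ktrain installs for a
    given Python version (defaults to the running interpreter)."""
    major, minor = tuple(python_version or sys.version_info)[:2]
    if (major, minor) == (3, 8):
        return '2.2.0'
    if (major, minor) in {(3, 6), (3, 7)}:
        return '2.1.0'
    raise ValueError(f'Python {major}.{minor} is not covered by these notes')

print(expected_tensorflow_version((3, 7)))  # 2.1.0
print(expected_tensorflow_version((3, 8)))  # 2.2.0
```

To check what was actually installed, `python3 -c "import tensorflow as tf; print(tf.__version__)"` prints the TensorFlow version in the current environment.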
This code was tested on Ubuntu 18.04 LTS using TensorFlow 2.1.0 and 2.2.0 and Python 3.6.9.
Please cite the following paper when using ktrain:
@article{maiya2020ktrain,
title={ktrain: A Low-Code Library for Augmented Machine Learning},
author={Arun S. Maiya},
journal={arXiv},
year={2020},
volume={arXiv:2004.10703 [cs.LG]}
}
Creator: Arun S. Maiya
Email: arun [at] maiya [dot] net