Getting Started
Stormtrooper is a lightweight Python library for zero- and few-shot classification using transformer models. All components are fully scikit-learn compatible, which makes it easy to integrate them into your scikit-learn workflows and pipelines.
Installation
You can install stormtrooper from PyPI.
pip install stormtrooper
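From version 0.4.0 you can also use OpenAI models in stormtrooper. To do so, set your API key as an environment variable and pass the name of the OpenAI model to Trooper.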
export OPENAI_API_KEY="sk-..."
from stormtrooper import Trooper
model = Trooper("gpt-4")
Usage
To get started you can load a model from HuggingFace Hub.
from stormtrooper import Trooper
class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
"God came down to earth to save us.",
"A new nebula was recently discovered in the proximity of the Oort cloud."
]
# Initializing a zero-shot text2text model.
model = Trooper("google/flan-t5-base").fit(None, class_labels)
predictions = model.predict(example_texts)
Zero-shot classification
When you don't have access to labelled data but need to classify textual data, zero-shot classification is a good option. All models you can use in Stormtrooper are capable of zero-shot classification.
You can initialize a zero-shot model by passing no labelled examples to the fit() method, only an exhaustive list of potential labels.
model.fit(None, y=["dog", "cat"])
Few-shot classification
Few-shot classification is when a model, along with the labels, has access to a small number of examples for each label. Different models in Stormtrooper have different approaches to utilising these examples.
X = [
"He was barking like hell.",
"Purring on my lap is a most curious creature.",
"Needed to put the pet on leash."
]
y = ["dog", "cat", "dog"]
model.fit(X, y)
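How the examples reach the model depends on the model type. For prompt-based models, one plausible approach is to render the labelled examples as text that is placed in the prompt. The sketch below is an illustration only: render_examples is a hypothetical helper, not part of stormtrooper, and the layout the library actually uses may differ.

```python
def render_examples(texts, labels):
    """Render labelled examples as a plain-text block (hypothetical helper)."""
    entries = []
    for text, label in zip(texts, labels):
        # One entry per labelled example: the text, then its label.
        entries.append(f"Text: '{text}'\nLabel: {label}")
    return "\n\n".join(entries)


X = [
    "He was barking like hell.",
    "Purring on my lap is a most curious creature.",
    "Needed to put the pet on leash.",
]
y = ["dog", "cat", "dog"]

examples_block = render_examples(X, y)
print(examples_block)
```

Keeping the number of examples per label small keeps prompts short, which matters for models with limited context windows.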
Custom Prompts
Models relying on text generation can be used with custom prompts. These might result in better performance than the original generic prompts that come prepackaged with the library. Additionally, chat models can have a system prompt passed along to them. In this example we will use a small LLM and specify a custom prompt template.
from stormtrooper import Trooper
system_prompt = """
You are a pet classifier. When you are presented with a sentence, you recognize which pet the
sentence is about.
You only respond with the brief name of the pet and nothing else.
Please follow the user's instructions as precisely as you can.
"""
prompt = """
Your task will be to classify a sentence into one
of the following classes: {classes}.
{examples}
Classify the following piece of text:
'{X}'
"""
model = Trooper("TinyLlama/TinyLlama-1.1B-Chat-v1.0", prompt=prompt, system_prompt=system_prompt)
model.fit(X, y)
# predict() expects an iterable of documents, so wrap a single text in a list.
model.predict(["Who's a good boy??"])
# expected label: 'dog'
Prompts get infused with texts {X} (represents an individual text), class labels {classes}, and potentially with labelled examples {examples}.
This happens using Python's str.format() method in the background.
As such, you have to put templated names in curly brackets.
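Because the substitution relies on str.format(), you can preview a filled-in prompt with plain Python. The snippet below is a sketch that assumes the class labels are joined with commas and that no examples are supplied; how stormtrooper renders {classes} and {examples} internally is not shown here. It also illustrates that literal curly brackets in a template must be doubled.

```python
template = (
    "Your task will be to classify a sentence into one\n"
    "of the following classes: {classes}.\n"
    "{examples}\n"
    "Classify the following piece of text:\n"
    "'{X}'\n"
)

filled = template.format(
    classes=", ".join(["dog", "cat"]),  # assumption: comma-separated labels
    examples="",                        # assumption: zero-shot, no examples
    X="Who's a good boy??",
)
print(filled)

# Doubled curly brackets survive str.format() as literal single brackets.
escaped = "Use the format {{label}}. Text: '{X}'".format(X="Woof")
print(escaped)
```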
Inference on GPU
To run models locally on a GPU, you can use the device argument of stormtrooper models.
model = Trooper("all-MiniLM-L6-v2", device="cuda")
Inference on multiple GPUs
You can distribute a model across multiple devices, in the order of priority GPU -> CPU + RAM -> disk, by using the device_map argument.
Note that this only works with text2text and generative models.
model = Trooper("HuggingFaceH4/zephyr-7b-beta", device_map="auto")
Trooper Interface
The Trooper class wraps all zero- and few-shot classifier models in stormtrooper and automatically detects what type the given model is.
This is determined based on the following order of preference:
API Reference
stormtrooper.trooper.Trooper
Bases: BaseEstimator, ClassifierMixin
Generic zero-shot, few-shot classifier. Automatically determines the type based on model name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name | str | Name of the base model to use. Could be a model from OpenAI or HuggingFace Hub. | required |
progress_bar | bool | Indicates whether a progress bar should be displayed during inference. | True |
device | Optional[str] | The device the model should run on (only valid for locally run models). | None |
device_map | Optional[str] | Device map argument for very large models. | None |
prompt | Optional[str] | Prompt to use for promptable models. | None |
system_prompt | Optional[str] | System prompt to use for chat models. | None |
fuzzy_match | bool | Indicates whether responses should be fuzzy matched to the closest class name, when using a generative model. This is useful, as large language models sometimes misspell words or capitalize them when responding to queries. With fuzzy matching you can rest assured that you will only get responses exactly matching target labels. | True |
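To illustrate what fuzzy matching buys you, the sketch below maps raw model responses to the closest class label using difflib from the Python standard library. This only approximates the idea: stormtrooper's own matching may use a different algorithm, and closest_label is a hypothetical helper.

```python
import difflib


def closest_label(response, labels):
    """Return the label most similar to a raw model response (hypothetical helper)."""
    # Compare case-insensitively, but return the label in its original spelling.
    lowered = {label.lower(): label for label in labels}
    # cutoff=0.0 means some label is always returned, however poor the match.
    match = difflib.get_close_matches(
        response.strip().lower(), list(lowered), n=1, cutoff=0.0
    )
    return lowered[match[0]]


labels = ["dog", "cat"]
print(closest_label("Dog.", labels))
print(closest_label("catt", labels))
```

The trade-off is that a response unrelated to every label is still mapped to one of them, so fuzzy matching guarantees valid labels rather than correct ones.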
Source code in stormtrooper/trooper.py
fit(X, y)
Learns class labels and potential examples.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | Optional[Iterable[str]] | Examples to use in few-shot prompts. Pass None when no examples are to be used. | required |
y | Iterable[str] | Class labels. | required |
partial_fit(X, y)
Learns class labels and potential examples in a batch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | Optional[Iterable[str]] | Examples to use in few-shot prompts. Pass None when no examples are to be used. | required |
y | Iterable[str] | Class labels. | required |
predict(X)
Predicts labels for a set of examples.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | Iterable[str] | Documents to predict labels for. | required |
Returns:
Type | Description |
---|---|
ndarray of shape (n_documents,) | Labels for documents. |
predict_proba(X)
Predicts probability of each label. Only available for certain models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
X | Iterable[str] | Documents to predict label probabilities for. | required |
Returns:
Type | Description |
---|---|
ndarray of shape (n_documents, n_labels) | Label distributions for documents. |
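A common use of a label distribution is to recover the most likely label together with its probability. The sketch below uses hand-written numbers in place of real model output and assumes that the columns follow the order of the class labels; both the values and that ordering are assumptions made for illustration.

```python
classes = ["dog", "cat"]  # assumed column order of the distribution

# Stand-in for the output of predict_proba: one row per document.
probabilities = [
    [0.91, 0.09],
    [0.20, 0.80],
]


def top_labels(probabilities, classes):
    """Pair each document with its most probable label and that probability."""
    results = []
    for row in probabilities:
        best = max(range(len(row)), key=row.__getitem__)  # index of the largest value
        results.append((classes[best], row[best]))
    return results


print(top_labels(probabilities, classes))
```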