
Getting Started

Stormtrooper is a lightweight Python library for zero-shot and few-shot classification using transformer models. All components are fully scikit-learn compatible, which makes them easy to integrate into your scikit-learn workflows and pipelines.

Installation

You can install stormtrooper from PyPI.

pip install stormtrooper

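Since version 0.4.0, you can also use OpenAI models in stormtrooper. First set your API key as an environment variable in your shell:

export OPENAI_API_KEY="sk-..."

Then pass an OpenAI model name to Trooper:

from stormtrooper import Trooper

model = Trooper("gpt-4")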

Usage

To get started, load a model from the HuggingFace Hub.

from stormtrooper import Trooper

class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
    "God came down to earth to save us.",
    "A new nebula was recently discovered in the proximity of the Oort cloud."
]

# Initializing a zero-shot text2text model.
model = Trooper("google/flan-t5-base").fit(None, class_labels)
predictions = model.predict(example_texts)

Zero-shot classification

When you need to classify textual data but have no labelled data, zero-shot classification is a good option. All models you can use in Stormtrooper are capable of zero-shot classification.

To initialize a zero-shot model, pass None instead of labelled examples to the fit() method, together with an exhaustive list of potential labels.

model.fit(None, y=["dog", "cat"])

Few-shot classification

In few-shot classification, the model has access not only to the labels but also to a small number of examples for each label. Different models in Stormtrooper use these examples in different ways.

X = [
  "He was barking like hell.",
  "Purring on my lap is a most curious creature.",
  "Needed to put the pet on leash."
]
y = ["dog", "cat", "dog"]
model.fit(X, y)
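Prompt-based models receive such examples through the {examples} placeholder of their prompt (see Custom Prompts). As a rough illustration only, a hypothetical format_examples helper could render X and y into a prompt block as shown below. The grouping by label and the "'text' -> label" layout are assumptions, not stormtrooper's actual formatting.

```python
from collections import defaultdict


def format_examples(X: list[str], y: list[str]) -> str:
    """Hypothetical helper: render labelled examples as a few-shot prompt block.

    The layout is an illustrative assumption, not stormtrooper's own format.
    """
    # Group example texts under their label, keeping first-seen label order.
    by_label: dict[str, list[str]] = defaultdict(list)
    for text, label in zip(X, y):
        by_label[label].append(text)
    lines = ["Here are some examples:"]
    for label, texts in by_label.items():
        for text in texts:
            lines.append(f"'{text}' -> {label}")
    return "\n".join(lines)


X = [
    "He was barking like hell.",
    "Purring on my lap is a most curious creature.",
    "Needed to put the pet on leash.",
]
y = ["dog", "cat", "dog"]
print(format_examples(X, y))
```

Grouping by label keeps all examples of a class together, which is one common way of presenting few-shot examples; listing them in their original order would work just as well.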

Custom Prompts

Models that rely on text generation can be used with custom prompts, which may perform better than the generic prompts prepackaged with the library. Chat models can additionally be given a system prompt. In this example, we use a small LLM and specify a custom prompt template.

from stormtrooper import Trooper

system_prompt = """
You are a pet classifier. When you are presented with a sentence, you recognize which pet the
sentence is about.
You only respond with the brief name of the pet and nothing else.
Please follow the user's instructions as precisely as you can.
"""

prompt = """
Your task will be to classify a sentence into one
of the following classes: {classes}.
{examples}
Classify the following piece of text:
'{X}'
"""

model = Trooper("TinyLlama/TinyLlama-1.1B-Chat-v1.0", prompt=prompt, system_prompt=system_prompt)
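model.fit(X, y)  # X and y are the labelled pet examples from the few-shot section

# predict() expects an iterable of texts, so wrap a single text in a list.
model.predict(["Who's a good boy??"])
# expected: one label per text, e.g. 'dog'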

Prompts are filled with the text to classify {X} (an individual text), the class labels {classes}, and, when available, labelled examples {examples}. This is done with Python's str.format() method in the background, so template field names have to be written in curly braces, and any literal braces in your prompt must be doubled ({{ and }}).
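The following sketch shows what str.format() does with a template like the one above. How stormtrooper itself renders the class list and the examples is not documented here, so joining the classes with ", " and passing an empty examples string are assumptions made for illustration.

```python
# A template in the style of the custom prompt above; {classes}, {examples}
# and {X} are the placeholders that get filled in.
prompt = """
Your task will be to classify a sentence into one
of the following classes: {classes}.
{examples}
Classify the following piece of text:
'{X}'
"""

# Assumption: classes joined with ", " and no labelled examples (zero-shot).
filled = prompt.format(
    classes=", ".join(["dog", "cat"]),
    examples="",
    X="Who's a good boy??",
)
print(filled)

# Literal curly braces must be doubled, otherwise str.format treats them as fields.
json_hint = 'Answer as JSON, e.g. {{"label": "dog"}}. Text: {X}'
print(json_hint.format(X="Meow!"))
# Answer as JSON, e.g. {"label": "dog"}. Text: Meow!
```

A placeholder that is not supplied to format() (for example a typo such as {clases}) raises a KeyError, which is a quick way to catch mistakes in a custom template.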

Inference on GPU

To run models locally on a GPU, use the device parameter of stormtrooper models.

model = Trooper("all-MiniLM-L6-v2", device="cuda")

Inference on multiple GPUs

To spread a model across multiple devices, use the device_map argument. With device_map="auto", the model is placed in order of device priority: GPU first, then CPU RAM, then disk. Note that this only works with text2text and generative models.

model = Trooper("HuggingFaceH4/zephyr-7b-beta", device_map="auto")

Trooper Interface

The Trooper class wraps all zero-shot and few-shot classifier models in stormtrooper and automatically detects the type of the given model. The source code below distinguishes generative, openai, text2text, nli, and setfit models.

This is determined based on the following order of preference:

[Figure: How models get loaded with the Trooper interface]
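The constructor in the source code below forwards each argument only to the model types that accept it. The sketch restates that routing as a standalone function so you can see which settings reach which model type. route_kwargs is a hypothetical name, the type strings are taken from the source, and system_prompt is left out for brevity.

```python
from typing import Optional


def route_kwargs(
    model_type: str,
    model_name: str,
    *,
    progress_bar: bool = True,
    device: Optional[str] = None,
    device_map: Optional[str] = None,
    prompt: Optional[str] = None,
    fuzzy_match: bool = True,
) -> dict:
    """Hypothetical mirror of the argument routing in Trooper.__init__."""
    kwargs: dict = {"model_name": model_name}
    # Prompt-based models get fuzzy matching and, if given, a custom prompt.
    if model_type in ("generative", "openai", "text2text"):
        kwargs["fuzzy_match"] = fuzzy_match
        if prompt is not None:
            kwargs["prompt"] = prompt
    # Only locally run models accept a device.
    if model_type in ("generative", "text2text", "nli", "setfit"):
        kwargs["device"] = device
    # device_map is only forwarded to generative and text2text models.
    if model_type in ("generative", "text2text"):
        kwargs["device_map"] = device_map
    # Of the five types, only setfit does not receive progress_bar.
    if model_type in ("nli", "generative", "text2text", "openai"):
        kwargs["progress_bar"] = progress_bar
    return kwargs


# device is silently dropped for OpenAI models, which do not run locally.
print(route_kwargs("openai", "gpt-4", device="cuda"))
# {'model_name': 'gpt-4', 'fuzzy_match': True, 'progress_bar': True}
```

One consequence of this design is that passing an argument a model type does not use (such as device for an OpenAI model) does not raise an error; it is simply not forwarded.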

API Reference

stormtrooper.trooper.Trooper

Bases: BaseEstimator, ClassifierMixin

Generic zero-shot, few-shot classifier. Automatically determines the type based on model name.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
model_name | str | Name of the base model to use. Could be a model from OpenAI or HuggingFace Hub. | required
progress_bar | bool | Indicates whether a progress bar should be displayed during inference. | True
device | Optional[str] | The device the model should run on (only valid for locally run models). | None
device_map | Optional[str] | Device map argument for very large models. | None
prompt | Optional[str] | Prompt to use for promptable models. | None
system_prompt | Optional[str] | System prompt to use for chat models. | None
fuzzy_match | bool | Indicates whether responses should be fuzzy matched to the closest class name, when using a generative model. This is useful, as large language models sometimes misspell words or capitalize them when responding to queries. With fuzzy matching you can rest assured that you will only get responses exactly matching target labels. | True
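The fuzzy_match option maps a free-text answer onto one of the known labels. The sketch below illustrates the idea with the standard library's difflib. The matching algorithm stormtrooper actually uses is not stated on this page, so treat fuzzy_match_label as a hypothetical helper and difflib similarity as a stand-in technique.

```python
import difflib


def fuzzy_match_label(response: str, classes: list[str]) -> str:
    """Map a free-text model response onto the closest class name.

    Illustrative only: difflib similarity may differ from the matching
    stormtrooper performs internally.
    """
    # Normalise whitespace, trailing punctuation/quotes, and capitalisation.
    cleaned = response.strip().strip(".!'\"").lower()
    # Score every class by character-level similarity to the cleaned response.
    scores = {
        label: difflib.SequenceMatcher(None, cleaned, label.lower()).ratio()
        for label in classes
    }
    # max() always returns one of the classes, so the output is a valid label.
    return max(scores, key=scores.get)


print(fuzzy_match_label("Dog.", ["dog", "cat"]))  # dog
print(fuzzy_match_label("kitty cat", ["dog", "cat"]))  # cat
print(fuzzy_match_label("astronmy", ["atheism/christianity", "astronomy/space"]))  # astronomy/space
```

Because the function always picks the best-scoring class, even an unrelated answer is forced onto some label; that is the guarantee described above, at the cost of hiding responses that matched nothing well.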
Source code in stormtrooper/trooper.py
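from typing import Iterable, Optional

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin

# get_model_type and model_type_to_cls are defined or imported earlier in
# stormtrooper/trooper.py (not shown in this excerpt).


class Trooper(BaseEstimator, ClassifierMixin):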
    """Generic zero-shot, few-shot classifier.
    Automatically determines the type based on model name.

    Parameters
    ----------
    model_name: str
        Name of the base model to use.
        Could be a model from OpenAI or HuggingFace Hub.
    progress_bar: bool, default True
        Indicates whether a progress bar should be displayed during inference.
    device: str, default None
        The device the model should run on (only valid for locally run models).
    device_map: str, default None
        Device map argument for very large models.
    prompt: str, default None
        Prompt to use for promptable models.
    system_prompt: str, default None
        System prompt to use for chat models.
    fuzzy_match: bool, default True
        Indicates whether responses should be fuzzy matched to the closest class name,
        when using a generative model.
        This is useful, as large language models sometimes misspell words or capitalize them
        when responding to queries.
        With fuzzy matching you can rest assured that you will only get responses exactly
        matching target labels.
    """

    def __init__(
        self,
        model_name: str,
        *,
        progress_bar: bool = True,
        device: Optional[str] = None,
        device_map: Optional[str] = None,
        prompt: Optional[str] = None,
        system_prompt: Optional[str] = None,
        fuzzy_match: bool = True,
    ):
        self.model_name = model_name
        self.model_type = get_model_type(model_name)
        self.progress_bar = progress_bar
        self.device_map = device_map
        self.device = device
        self.prompt = prompt
        self.system_prompt = system_prompt
        self.fuzzy_match = fuzzy_match
        model_kwargs = dict(model_name=model_name)
        if self.model_type in ["generative", "openai", "text2text"]:
            model_kwargs["fuzzy_match"] = self.fuzzy_match
            if self.prompt is not None:
                model_kwargs["prompt"] = self.prompt
            if self.model_type == "generative" and self.system_prompt is not None:
                # System prompts only apply to chat (generative) models.
                model_kwargs["system_prompt"] = self.system_prompt
        if self.model_type in ["generative", "text2text", "nli", "setfit"]:
            model_kwargs["device"] = self.device
        if self.model_type in ["generative", "text2text"]:
            model_kwargs["device_map"] = self.device_map
        if self.model_type in ["nli", "generative", "text2text", "openai"]:
            model_kwargs["progress_bar"] = self.progress_bar
        self.model = model_type_to_cls[self.model_type](**model_kwargs)

    def fit(self, X: Optional[Iterable[str]], y: Iterable[str]):
        """Learns class labels and potential examples.

        Parameters
        ----------
        X: iterable of str
            Examples to use in few-shot prompts.
            Pass None when no examples are to be used.
        y: iterable of str
            Class labels.
        """
        self.model.fit(X, y)
        return self

    def partial_fit(self, X: Optional[Iterable[str]], y: Iterable[str]):
        """Learns class labels and potential examples in a batch.

        Parameters
        ----------
        X: iterable of str
            Examples to use in few-shot prompts.
            Pass None when no examples are to be used.
        y: iterable of str
            Class labels.
        """
        self.model.partial_fit(X, y)
        return self

    def predict(self, X: Iterable[str]) -> np.ndarray:
        """Predicts labels for a set of examples.

        Parameters
        ----------
        X: iterable of str
            Documents to predict labels for.

        Returns
        -------
        ndarray of shape (n_documents,)
            Labels for documents.
        """
        return self.model.predict(X)

    def predict_proba(self, X: Iterable[str]) -> np.ndarray:
        """Predicts probability of each label.
        Only available for certain models.

        Parameters
        ----------
        X: iterable of str
            Documents to predict labels for.

        Returns
        -------
        ndarray of shape (n_documents, n_labels)
            Label distributions for documents.
        """
        return self.model.predict_proba(X)

    @property
    def classes_(self):
        return self.model.classes

fit(X, y)

Learns class labels and potential examples.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
X | Optional[Iterable[str]] | Examples to use in few-shot prompts. Pass None when no examples are to be used. | required
y | Iterable[str] | Class labels. | required

partial_fit(X, y)

Learns class labels and potential examples in a batch.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
X | Optional[Iterable[str]] | Examples to use in few-shot prompts. Pass None when no examples are to be used. | required
y | Iterable[str] | Class labels. | required
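partial_fit lets you supply labels and examples batch by batch. The following is a minimal sketch of the bookkeeping this implies, assuming the usual scikit-learn convention that fit starts over while partial_fit accumulates. stormtrooper's internal storage is not shown on this page, and LabelMemory is a hypothetical class.

```python
from typing import Iterable, Optional


class LabelMemory:
    """Hypothetical stand-in showing fit vs. partial_fit bookkeeping.

    Follows the usual scikit-learn convention (fit starts over, partial_fit
    accumulates); stormtrooper's internals may store things differently.
    """

    def __init__(self) -> None:
        self.classes_: list[str] = []
        self.examples_: list[tuple[str, str]] = []

    def fit(self, X: Optional[Iterable[str]], y: Iterable[str]) -> "LabelMemory":
        # Discard anything learned before, then learn this batch.
        self.classes_, self.examples_ = [], []
        return self.partial_fit(X, y)

    def partial_fit(self, X: Optional[Iterable[str]], y: Iterable[str]) -> "LabelMemory":
        y = list(y)
        # Add labels not seen before, keeping first-seen order.
        for label in y:
            if label not in self.classes_:
                self.classes_.append(label)
        # X is None in the zero-shot case: only labels are learned.
        if X is not None:
            self.examples_.extend(zip(X, y))
        return self


memory = LabelMemory().fit(None, ["dog", "cat"])
memory.partial_fit(["Chirping all morning."], ["bird"])
print(memory.classes_)   # ['dog', 'cat', 'bird']
print(memory.examples_)  # [('Chirping all morning.', 'bird')]
```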

predict(X)

Predicts labels for a set of examples.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
X | Iterable[str] | Documents to predict labels for. | required

Returns:

Type | Description
--- | ---
ndarray of shape (n_documents,) | Labels for documents.
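predict expects an iterable of documents and returns one label per document. A Python str is itself an iterable of one-character strings, so passing a bare string may, depending on how the underlying model handles it, be treated as many one-character documents. A small guard, using a hypothetical helper named as_documents:

```python
from typing import Iterable, List, Union


def as_documents(X: Union[str, Iterable[str]]) -> List[str]:
    """Hypothetical helper: wrap a lone string so it counts as one document."""
    # A str is itself an Iterable[str], so without this check it would be
    # split into one "document" per character.
    if isinstance(X, str):
        return [X]
    return list(X)


print(list("Good boy"))  # ['G', 'o', 'o', 'd', ' ', 'b', 'o', 'y']
print(as_documents("Who's a good boy??"))  # ["Who's a good boy??"]
print(as_documents(("Woof!", "Meow!")))  # ['Woof!', 'Meow!']
```

With a fitted model, you could then call model.predict(as_documents(text)) whether text is a single string or a list of strings.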


predict_proba(X)

Predicts probability of each label. Only available for certain models.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
X | Iterable[str] | Documents to predict labels for. | required

Returns:

Type | Description
--- | ---
ndarray of shape (n_documents, n_labels) | Label distributions for documents.
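Each row of the returned matrix is a distribution over labels for one document. Assuming the columns follow the order of the model's classes_ (the usual scikit-learn convention, not stated on this page), the most likely label per document is the argmax of each row. The sketch uses plain Python and made-up probabilities; proba_to_labels is a hypothetical name. With NumPy the same step is np.asarray(classes)[np.argmax(proba, axis=1)].

```python
def proba_to_labels(proba: list[list[float]], classes: list[str]) -> list[str]:
    """Pick the most probable class for each row of an (n_documents, n_labels) matrix."""
    labels = []
    for row in proba:
        # Index of the largest probability in this document's distribution.
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(classes[best])
    return labels


classes = ["atheism/christianity", "astronomy/space"]
proba = [
    [0.91, 0.09],  # made-up scores for illustration only
    [0.12, 0.88],
]
print(proba_to_labels(proba, classes))  # ['atheism/christianity', 'astronomy/space']
```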
