
Models

Stormtrooper supports a wide variety of model types for zero- and few-shot classification. This page gives a general overview of how the different methods approach the task.

Text Generation

LLMs are relatively easy to use for zero- and few-shot classification, as they contain a lot of general, language-based knowledge and can give free-form answers. Models that generate text typically have to be prompted: you pass free-form text instructions to the model, and it responds with a (hopefully) appropriate answer.
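For classification, the prompt usually combines the candidate labels with the text to be labelled. The sketch below fills a prompt template in this way. The template wording, the placeholder names `{classes}` and `{X}`, and the helper `build_prompt` are all hypothetical. They are not necessarily what Stormtrooper's `default_prompt` uses.

```python
# Hypothetical template; Stormtrooper's default_prompt may use other wording/placeholders.
PROMPT_TEMPLATE = (
    "Classify the following text into one of these categories: {classes}.\n"
    "Respond with the category name only.\n"
    'Text: "{X}"'
)


def build_prompt(template: str, classes: list[str], text: str) -> str:
    """Insert the candidate labels and the text to classify into the template."""
    return template.format(classes=", ".join(classes), X=text)


prompt = build_prompt(PROMPT_TEMPLATE, ["sports", "politics"], "The match ended 2-1.")
print(prompt)
```

The model's free-form answer to such a prompt still has to be mapped back onto one of the labels.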

Instruction Models

Models used in chatbots and the like are typically instruction-finetuned, generatively pretrained transformer models. These models take a sequence of messages and generate a new message at the end by repeatedly predicting next-token probabilities.

These models also typically take a system prompt: a base prompt that tells the model what persona to adopt and how to behave when presented with instructions.
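Chat-style models usually receive the system prompt and the user's instructions as a list of role/content messages. This format is accepted by OpenAI's Chat Completions API and by the chat input of HuggingFace text-generation pipelines. The helper `make_messages` below is a hypothetical illustration, not Stormtrooper's own message builder.

```python
def make_messages(system_prompt: str, user_prompt: str) -> list[dict[str, str]]:
    """Build a chat in the role/content format used by chat-style model APIs."""
    return [
        # The system message sets the persona and general behaviour.
        {"role": "system", "content": system_prompt},
        # The user message carries the actual classification instruction.
        {"role": "user", "content": user_prompt},
    ]


messages = make_messages(
    "You are a helpful classification assistant.",
    "Is this text about sports or politics? Text: The match ended 2-1.",
)
for message in messages:
    print(f"{message['role']}: {message['content']}")
```

The model replies by appending a new message with the role `assistant`. Its `content` is the answer that then has to be matched to a label.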

In Stormtrooper you can use instruction models both from the HuggingFace Hub and from OpenAI.

from stormtrooper import Trooper

# Model from HuggingFace:
model = Trooper("HuggingFaceH4/zephyr-7b-beta")

# OpenAI model (requires the OPENAI_API_KEY environment variable)
model = Trooper("gpt-4")

stormtrooper.openai.OpenAIClassifier

Bases: ChatClassifier

Use OpenAI's models for zero and few-shot text classification.

Parameters:

Name Type Description Default
model_name str

Name of the OpenAI chat model to use.

'gpt-3.5-turbo'
temperature float

Temperature for text generation. Higher temperature results in more diverse answers.

1.0
prompt str

Prompt template to use for each text.

default_prompt
system_prompt str

System prompt for the model.

default_system_prompt
max_new_tokens int

Maximum number of new tokens to generate.

256
fuzzy_match bool

Indicates whether responses should be fuzzy-matched to the closest learned label.

True
progress_bar bool

Indicates whether a progress bar should be displayed when obtaining results.

True
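A model's free-form answer may not spell a label exactly, for example "Sport" instead of "sports". The `fuzzy_match` option addresses this by mapping the answer to the closest learned label. The sketch below shows one way to do this, using the standard library's `difflib.SequenceMatcher` similarity ratio. This choice of similarity measure is an assumption, and Stormtrooper's actual matching may work differently. `fuzzy_match_label` is a hypothetical helper.

```python
from difflib import SequenceMatcher


def fuzzy_match_label(response: str, labels: list[str]) -> str:
    """Return the label whose spelling is most similar to the model's response."""
    cleaned = response.strip().lower()

    def similarity(label: str) -> float:
        # Ratio in [0, 1]; 1.0 means the two strings are identical.
        return SequenceMatcher(None, cleaned, label.lower()).ratio()

    # max() keeps the first label in case of ties.
    return max(labels, key=similarity)


print(fuzzy_match_label("Sport", ["sports", "politics"]))  # sports
```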
Source code in stormtrooper/openai.py
class OpenAIClassifier(ChatClassifier):
    """Use OpenAI's models for zero and few-shot text classification.

    Parameters
    ----------
    model_name: str, default "gpt-3.5-turbo"
        Name of the OpenAI chat model to use.
    temperature: float, default 1.0
        Temperature for text generation.
        Higher temperature results in more diverse answers.
    prompt: str
        Prompt template to use for each text.
    system_prompt: str
        System prompt for the model.
    max_new_tokens: int, default 256
        Maximum number of new tokens to generate.
    fuzzy_match: bool, default True
        Indicates whether responses should be fuzzy-matched to the closest learned label.
    progress_bar: bool, default True
        Indicates whether a progress bar should be displayed when obtaining results.
    """

    def __init__(
        self,
        model_name: str = "gpt-3.5-turbo",
        temperature: float = 1.0,
        prompt: str = default_prompt,
        system_prompt: str = default_system_prompt,
        max_new_tokens: int = 256,
        fuzzy_match: bool = True,
        progress_bar: bool = True,
    ):
        self.model_name = model_name
        self.prompt = prompt
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.classes_ = None
        self.max_new_tokens = max_new_tokens
        self.fuzzy_match = fuzzy_match
        self.progress_bar = progress_bar
        try:
            openai.api_key = os.environ["OPENAI_API_KEY"]
            openai.organization = os.environ.get("OPENAI_ORG")
            client = openai.OpenAI(
                api_key=openai.api_key, organization=openai.organization
            )
            valid_model_ids = [model.id for model in client.models.list()]
            if model_name not in valid_model_ids:
                raise ValueError(
                    f"{model_name} is not a valid model ID for OpenAI."
                )
        except KeyError as e:
            raise KeyError(
                "Environment variable OPENAI_API_KEY not specified."
            ) from e
        self.client = openai.AsyncOpenAI(
            api_key=openai.api_key, organization=openai.organization
        )

    async def predict_one_async(self, text: str) -> str:
        messages = self.generate_messages(text)
        response = await self.client.chat.completions.create(
            messages=messages,
            model=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
        )
        return response.choices[0].message.content

    def predict_one(self, text: str) -> str:
        return asyncio.run(self.predict_one_async(text))

    async def predict_async(self, X: Iterable[str]) -> np.ndarray:
        if self.classes_ is None:
            raise NotFittedError(
                "Class labels have not been collected yet, please fit."
            )
        if self.progress_bar:
            X = tqdm(X)
        res = await asyncio.gather(*[self.predict_one_async(x) for x in X])
        return np.array(res)

    def predict(self, X: Iterable[str]) -> np.ndarray:
        """Predicts most probable class label for given texts.

        Parameters
        ----------
        X: iterable of str
            Texts to label.

        Returns
        -------
        array of shape (n_texts)
            Array of string class labels.
        """
        return asyncio.run(self.predict_async(X))
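`predict_async` sends one request per text and awaits them all together with `asyncio.gather`. The self-contained sketch below shows why this is safe for labelling: requests run concurrently, but `gather` returns the results in input order, not completion order. `fake_request` is a hypothetical stand-in for the API call; no OpenAI request is made.

```python
import asyncio
import random


async def fake_request(text: str) -> str:
    """Stand-in for an API call: waits a random time, then returns a label."""
    await asyncio.sleep(random.uniform(0.0, 0.05))
    return f"label for {text!r}"


async def predict_all(texts: list[str]) -> list[str]:
    # Start all requests at once; gather returns results in input order,
    # regardless of which request finishes first.
    return await asyncio.gather(*[fake_request(text) for text in texts])


results = asyncio.run(predict_all(["a", "b", "c"]))
print(results)
```

Wrapping `X` in `tqdm` before building the coroutines, as `predict_async` does, advances the bar while the coroutines are created, before any request has completed. If you need the bar to track completed requests, tqdm's `tqdm.asyncio` module offers a `tqdm_asyncio.gather` variant.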

predict(X)

Predicts most probable class label for given texts.

Parameters:

Name Type Description Default
X Iterable[str]

Texts to label.

required

Returns:

Type Description
array of shape (n_texts)

Array of string class labels.


stormtrooper.generative.GenerativeClassifier

Bases: ChatClassifier

Scikit-learn compatible zero shot classification with generative language models.

Parameters:

Name Type Description Default
model_name str

Generative instruct model from HuggingFace.

'HuggingFaceH4/zephyr-7b-beta'
prompt str

You can specify the prompt which will be used to prompt the model. Use placeholders to indicate where the class labels and the data should be placed in the prompt.

default_prompt
system_prompt str

System prompt for the model.

default_system_prompt
max_new_tokens int

Maximum number of tokens the model should generate.

256
fuzzy_match bool

Indicates whether the output labels should be fuzzy matched to the learnt class labels. This is useful when the model isn't giving specific enough answers.

True
progress_bar bool

Indicates whether a progress bar should be shown.

True
device Optional[str]

Indicates which device should be used for classification. Models are by default run on CPU.

None
device_map Optional[str]

Device map argument for very large models.

None

Attributes:

Name Type Description
classes_ array of str

Class names learned from the labels.

Source code in stormtrooper/generative.py
class GenerativeClassifier(ChatClassifier):
    """Scikit-learn compatible zero shot classification
    with generative language models.

    Parameters
    ----------
    model_name: str, default 'HuggingFaceH4/zephyr-7b-beta'
        Generative instruct model from HuggingFace.
    prompt: str, optional
        You can specify the prompt which will be used to prompt the model.
        Use placeholders to indicate where the class labels and the
        data should be placed in the prompt.
    system_prompt: str, optional
        System prompt for the model.
    max_new_tokens: int, default 256
        Maximum number of tokens the model should generate.
    fuzzy_match: bool, default True
        Indicates whether the output labels should be fuzzy matched
        to the learnt class labels.
        This is useful when the model isn't giving specific enough answers.
    progress_bar: bool, default True
        Indicates whether a progress bar should be shown.
    device: str, default None
        Indicates which device should be used for classification.
        Models are by default run on CPU.
    device_map: str, default None
        Device map argument for very large models.

    Attributes
    ----------
    classes_: array of str
        Class names learned from the labels.
    """

    def __init__(
        self,
        model_name: str = "HuggingFaceH4/zephyr-7b-beta",
        prompt: str = default_prompt,
        system_prompt: str = default_system_prompt,
        max_new_tokens: int = 256,
        fuzzy_match: bool = True,
        progress_bar: bool = True,
        device: Optional[str] = None,
        device_map: Optional[str] = None,
    ):
        self.model_name = model_name
        self.prompt = prompt
        self.device = device
        self.system_prompt = system_prompt
        self.device_map = device_map
        self.pipeline = pipeline(
            "text-generation",
            self.model_name,
            device=self.device,
            device_map=self.device_map,
        )
        self.classes_ = None
        self.max_new_tokens = max_new_tokens
        self.fuzzy_match = fuzzy_match
        self.progress_bar = progress_bar

    def predict_one(self, text: str) -> str:
        messages = self.generate_messages(text)
        response = self.pipeline(messages, max_new_tokens=self.max_new_tokens)[
            0
        ]["generated_text"][-1]
        label = response["content"]
        return label

Text2Text Models

Text2Text models do not just generate text; they are trained to predict an output sequence of text from an input sequence of text. An encoder maps the input text into a latent representation, which a decoder then uses to generate an appropriate response, somewhat like an autoencoder.

Text2Text models are typically smaller and faster than fully generative models, but they also tend to give less accurate results.

from stormtrooper import Trooper

model = Trooper("google/flan-t5-small")

stormtrooper.text2text.Text2TextClassifier

Bases: ChatClassifier

Zero and few-shot classification with text2text language models.

Parameters:

Name Type Description Default
model_name str

Text2text model from HuggingFace.

'google/flan-t5-base'
prompt str

You can specify the prompt which will be used to prompt the model. Use placeholders to indicate where the class labels and the data should be placed in the prompt.

default_prompt
max_new_tokens int

Maximum number of tokens the model should generate.

256
fuzzy_match bool

Indicates whether the output labels should be fuzzy matched to the learnt class labels. This is useful when the model isn't giving specific enough answers.

True
progress_bar bool

Indicates whether a progress bar should be shown.

True
device Optional[str]

Indicates which device should be used for classification. Models are by default run on CPU.

None
device_map Optional[str]

Device map argument for very large models.

None

Attributes:

Name Type Description
classes_ array of str

Class names learned from the labels.

Source code in stormtrooper/text2text.py
class Text2TextClassifier(ChatClassifier):
    """Zero and few-shot classification
    with text2text language models.

    Parameters
    ----------
    model_name: str, default 'google/flan-t5-base'
        Text2text model from HuggingFace.
    prompt: str, optional
        You can specify the prompt which will be used to prompt the model.
        Use placeholders to indicate where the class labels and the
        data should be placed in the prompt.
    max_new_tokens: int, default 256
        Maximum number of tokens the model should generate.
    fuzzy_match: bool, default True
        Indicates whether the output labels should be fuzzy matched
        to the learnt class labels.
        This is useful when the model isn't giving specific enough answers.
    progress_bar: bool, default True
        Indicates whether a progress bar should be shown.
    device: str, default None
        Indicates which device should be used for classification.
        Models are by default run on CPU.
    device_map: str, default None
        Device map argument for very large models.


    Attributes
    ----------
    classes_: array of str
        Class names learned from the labels.
    """

    def __init__(
        self,
        model_name: str = "google/flan-t5-base",
        prompt: str = default_prompt,
        max_new_tokens: int = 256,
        fuzzy_match: bool = True,
        progress_bar: bool = True,
        device: Optional[str] = None,
        device_map: Optional[str] = None,
    ):
        self.model_name = model_name
        self.prompt = prompt
        self.device = device
        self.device_map = device_map
        self.pipeline = pipeline(
            "text2text-generation",
            model=model_name,
            device=device,
            device_map=device_map,
        )
        self.classes_ = None
        self.progress_bar = progress_bar
        self.max_new_tokens = max_new_tokens
        self.fuzzy_match = fuzzy_match

    def predict_one(self, text: str) -> str:
        prompt = self.get_user_prompt(text)
        response = self.pipeline(prompt, max_new_tokens=self.max_new_tokens)
        return response[0]["generated_text"]

Sentence Transformers + SetFit

Schematic overview of SetFit

SetFit is a commonly used technique for training classifiers on a small number of labelled data points. It involves:

  1. Fine-tuning a sentence encoder model with a contrastive loss, where positive pairs are examples that belong to the same class and negative pairs are examples that belong to different classes.
  2. Training a classification head on the finetuned embeddings.
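The first step depends on turning a handful of labelled texts into many training pairs. The sketch below enumerates every pair of examples and marks it 1.0 (same class) or 0.0 (different classes). Pairs of this form could then be fed to a pair-based loss, for instance sentence-transformers' `CosineSimilarityLoss`. `make_contrastive_pairs` is a hypothetical helper. Stormtrooper's `finetune_contrastive` may sample pairs differently rather than enumerating all of them.

```python
from itertools import combinations


def make_contrastive_pairs(
    texts: list[str], labels: list[str]
) -> list[tuple[str, str, float]]:
    """Build (text_a, text_b, target) pairs: 1.0 for same class, 0.0 otherwise."""
    pairs = []
    # Every unordered pair of examples; fine for few-shot sizes (n * (n - 1) / 2 pairs).
    for (text_a, label_a), (text_b, label_b) in combinations(zip(texts, labels), 2):
        target = 1.0 if label_a == label_b else 0.0
        pairs.append((text_a, text_b, target))
    return pairs


texts = ["I love this", "Great product", "Terrible", "Broke after a day"]
labels = ["positive", "positive", "negative", "negative"]
for pair in make_contrastive_pairs(texts, labels):
    print(pair)
```

Even four labelled examples yield six pairs, which is why contrastive fine-tuning can work with very little labelled data.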

When you load any encoder-style model in Stormtrooper, it is automatically converted into a SetFit model.

from stormtrooper import Trooper

model = Trooper("all-MiniLM-L6-v2")

stormtrooper.set_fit.SetFitClassifier

Bases: BaseEstimator, ClassifierMixin

Zero and few-shot classifier using the SetFit technique with encoder models.

Parameters:

Name Type Description Default
model_name str

Name of the encoder model.

'sentence-transformers/all-MiniLM-L6-v2'
device str

Device to train and run the model on.

'cpu'
classification_head Optional[ClassifierMixin]

Classifier to use as the last step. Defaults to Logistic Regression when not specified.

None
n_epochs int

Number of training epochs.

10
batch_size int

Batch size to use during training.

32
sample_size int

Number of training samples to generate (only zero-shot)

8
template_sentence str

Template sentence for synthetic samples (only zero-shot)

'This sentence is {label}.'
random_state int

Seed to use for stochastic training.

42
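In the zero-shot case (`fit(None, y)`), no example texts are available. The source below instead calls `generate_synthetic_samples` with `sample_size` and `template_sentence`. The following is a hypothetical re-creation of what such a helper could do: it fills the template with each distinct label `n_sample_per_label` times. The real helper may differ, for example by varying the generated sentences.

```python
def make_synthetic_samples(
    labels: list[str],
    n_sample_per_label: int = 8,
    template_sentence: str = "This sentence is {label}.",
) -> tuple[list[str], list[str]]:
    """Create n_sample_per_label template sentences for every distinct label."""
    X, y = [], []
    for label in dict.fromkeys(labels):  # unique labels, in first-seen order
        X.extend([template_sentence.format(label=label)] * n_sample_per_label)
        y.extend([label] * n_sample_per_label)
    return X, y


X, y = make_synthetic_samples(["dog", "cat"], n_sample_per_label=2)
print(X)
print(y)
```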
Source code in stormtrooper/set_fit.py
class SetFitClassifier(BaseEstimator, ClassifierMixin):
    """Zero and few-shot classifier using the SetFit technique with encoder models.

    Parameters
    ----------
    model_name: str, default 'sentence-transformers/all-MiniLM-L6-v2'
        Name of the encoder model.
    device: str, default 'cpu'
        Device to train and run the model on.
    classification_head: ClassifierMixin, default None
        Classifier to use as the last step.
        Defaults to Logistic Regression when not specified.
    n_epochs: int, default 10
        Number of training epochs.
    batch_size: int, default 32
        Batch size to use during training.
    sample_size: int, default 8
        Number of training samples to generate (only zero-shot).
    template_sentence: str, default "This sentence is {label}."
        Template sentence for synthetic samples (only zero-shot).
    random_state: int, default 42
        Seed to use for stochastic training.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        classification_head: Optional[ClassifierMixin] = None,
        device: str = "cpu",
        n_epochs: int = 10,
        batch_size: int = 32,
        sample_size: int = 8,
        template_sentence: str = "This sentence is {label}.",
        random_state: int = 42,
    ):
        self.model_name = model_name
        self.classes_ = None
        self.template_sentence = template_sentence
        self.random_state = random_state
        self.encoder = SentenceTransformer(model_name, device=device)
        if classification_head is None:
            self.classification_head = LogisticRegression()
        else:
            self.classification_head = classification_head
        self.trainer = None
        self.sample_size = sample_size
        self.device = device
        self.n_epochs = n_epochs
        self.batch_size = batch_size

    def fit(self, X: Optional[Iterable[str]], y: Iterable[str]):
        # Materialise the iterables up front: they are traversed several times below.
        y = list(y)
        if X is not None:
            X = list(X)
            self.examples_ = dict()
            for text, label in zip(X, y):
                if label not in self.examples_:
                    self.examples_[label] = []
                self.examples_[label].append(text)
        else:
            # Zero-shot: no texts given, so generate synthetic ones from the labels.
            X, y = generate_synthetic_samples(
                y,
                n_sample_per_label=self.sample_size,
                template_sentence=self.template_sentence,
            )
        self.encoder = finetune_contrastive(
            self.encoder,
            X,
            y,
            n_epochs=self.n_epochs,
            seed=self.random_state,
        )
        X = list(X)
        X_embeddings = self.encoder.encode(X)
        self.classification_head.fit(X_embeddings, y)
        self.classes_ = self.classification_head.classes_
        self.n_classes = len(self.classes_)
        return self

    def predict(self, X: Iterable[str]) -> np.ndarray:
        if getattr(self, "classes_", None) is None:
            raise NotFittedError("You need to fit the model before running inference.")
        return self.classification_head.predict(self.encoder.encode(list(X)))

    def predict_proba(self, X: Iterable[str]) -> np.ndarray:
        if getattr(self, "classes_", None) is None:
            raise NotFittedError("You need to fit the model before running inference.")
        return self.classification_head.predict_proba(self.encoder.encode(list(X)))

Natural Language Inference

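Natural language inference (NLI) involves classifying a pair of texts by whether the first (the premise) entails, contradicts, or is neutral towards the second (the hypothesis). Models fine-tuned for NLI can also be used for zero-shot classification: each candidate label is turned into a hypothesis (e.g. "This example is {label}."), and the label whose hypothesis is most strongly entailed by the text is predicted.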

from stormtrooper import Trooper

model = Trooper("facebook/bart-large-mnli").fit(None, ["dog", "cat"])
model.predict_proba(["He was barking like hell"])
# -> array of shape (1, 2): one probability per label, in the order of model.classes_
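The sketch below illustrates how entailment scores become class probabilities. Each label is turned into a hypothesis. The model scores how strongly the text entails each hypothesis, and a softmax over those entailment logits gives one probability per label. This reflects how HuggingFace's zero-shot classification pipeline behaves with its default `multi_label=False`, which `ZeroShotClassifier` uses. As a consequence, each row of `predict_proba` should sum to 1. The logits below are invented for illustration and are not real model output.

```python
import numpy as np

labels = ["dog", "cat"]
# One hypothesis per candidate label; HuggingFace's zero-shot pipeline
# uses "This example is {}." as its default hypothesis template.
hypotheses = [f"This example is {label}." for label in labels]

# Illustrative entailment logits (premise vs. each hypothesis), NOT real model output.
entailment_logits = np.array([3.0, -1.0])

# Softmax over the candidate labels (shifted by the max for numerical stability).
exp = np.exp(entailment_logits - entailment_logits.max())
probabilities = exp / exp.sum()
print(hypotheses)
print(probabilities)
```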

stormtrooper.zero_shot.ZeroShotClassifier

Bases: BaseEstimator, TransformerMixin, ClassifierMixin

Scikit-learn compatible zero shot classification with HuggingFace Transformers.

Parameters:

Name Type Description Default
model_name str

Zero-shot model to load from HuggingFace.

'facebook/bart-large-mnli'
progress_bar bool

Indicates whether a progress bar should be shown.

True
device str

Indicates which device should be used for classification. Models are by default run on CPU.

'cpu'

Attributes:

Name Type Description
classes_ array of str

Class names learned from the labels.

Source code in stormtrooper/zero_shot.py
class ZeroShotClassifier(BaseEstimator, TransformerMixin, ClassifierMixin):
    """Scikit-learn compatible zero shot classification
    with HuggingFace Transformers.

    Parameters
    ----------
    model_name: str, default 'facebook/bart-large-mnli'
        Zero-shot model to load from HuggingFace.
    progress_bar: bool, default True
        Indicates whether a progress bar should be shown.
    device: str, default 'cpu'
        Indicates which device should be used for classification.
        Models are by default run on CPU.

    Attributes
    ----------
    classes_: array of str
        Class names learned from the labels.
    """

    def __init__(
        self,
        model_name: str = "facebook/bart-large-mnli",
        progress_bar: bool = True,
        device: str = "cpu",
    ):
        self.model_name = model_name
        self.pipe = pipeline(model=model_name, device=device)
        self.classes_ = None
        self.pandas_out = False
        self.progress_bar = progress_bar

    def fit(self, X, y: Iterable[str]):
        """Learns class labels.

        Parameters
        ----------
        X: Any
            Ignored
        y: iterable of str
            Iterable of class labels.
            Should at least contain a representative sample
            of potential labels.

        Returns
        -------
        self
            Fitted model.
        """
        self.classes_ = np.array(list(set(y)))
        self.n_classes = len(self.classes_)
        return self

    def partial_fit(self, X, y: Iterable[str]):
        """Learns class labels.
        Can learn new labels if new ones are encountered in the data.

        Parameters
        ----------
        X: Any
            Ignored
        y: iterable of str
            Iterable of class labels.

        Returns
        -------
        self
            Fitted model.
        """
        if self.classes_ is None:
            self.classes_ = np.array(list(set(y)))
        else:
            new_labels = set(y) - set(self.classes_)
            if new_labels:
                self.classes_ = np.concatenate([self.classes_, list(new_labels)])
        self.n_classes = len(self.classes_)
        return self

    def predict_proba(self, X: Iterable[str]) -> np.ndarray:
        """Predicts class probabilities for given texts.

        Parameters
        ----------
        X: iterable of str
            Texts to predict probabilities for.

        Returns
        -------
        array of shape (n_texts, n_classes)
            Class probabilities for each text.
        """
        if self.classes_ is None:
            raise NotFittedError(
                "No class labels have been learned by the model, please fit()."
            )
        X = list(X)
        n_texts = len(X)
        res = np.empty((n_texts, self.n_classes))
        if self.progress_bar:
            X = tqdm(X)
        for i_doc, text in enumerate(X):
            out = self.pipe(text, candidate_labels=self.classes_)
            label_to_score = dict(zip(out["labels"], out["scores"]))  # type: ignore
            for i_class, label in enumerate(self.classes_):
                res[i_doc, i_class] = label_to_score[label]
        return res

    def transform(self, X: Iterable[str]):
        """Predicts class probabilities for given texts.

        Parameters
        ----------
        X: iterable of str
            Texts to predict probabilities for.

        Returns
        -------
        array of shape (n_texts, n_classes)
            Class probabilities for each text.
        """
        res = self.predict_proba(X)
        if self.pandas_out:
            import pandas as pd

            return pd.DataFrame(res, columns=self.classes_)
        else:
            return res

    def predict(self, X: Iterable[str]) -> np.ndarray:
        """Predicts most probable class label for given texts.

        Parameters
        ----------
        X: iterable of str
            Texts to label.

        Returns
        -------
        array of shape (n_texts)
            Array of string class labels.
        """
        probs = self.transform(X)
        label_indices = np.argmax(probs, axis=1)
        return self.classes_[label_indices]  # type: ignore

    def get_feature_names_out(self) -> np.ndarray:
        if self.classes_ is None:
            raise NotFittedError(
                "No class labels have been learned by the model, please fit()."
            )
        return self.classes_

    @property
    def class_to_index(self) -> dict[str, int]:
        if self.classes_ is None:
            raise NotFittedError(
                "No class labels have been learned by the model, please fit()."
            )
        return dict(zip(self.classes_, range(self.n_classes)))

    def set_output(self, transform=None):
        """Set output of the transform() function to be a dataframe instead of
        a matrix if you pass transform='pandas'.
        Otherwise it will disable pandas output."""
        if transform == "pandas":
            self.pandas_out = True
        else:
            self.pandas_out = False
        return self

fit(X, y)

Learns class labels.

Parameters:

Name Type Description Default
X

Ignored

required
y Iterable[str]

Iterable of class labels. Should at least contain a representative sample of potential labels.

required

Returns:

Type Description
self

Fitted model.


partial_fit(X, y)

Learns class labels. Can learn new labels if new ones are encountered in the data.

Parameters:

Name Type Description Default
X

Ignored

required
y Iterable[str]

Iterable of class labels.

required

Returns:

Type Description
self

Fitted model.


predict(X)

Predicts most probable class label for given texts.

Parameters:

Name Type Description Default
X Iterable[str]

Texts to label.

required

Returns:

Type Description
array of shape (n_texts)

Array of string class labels.


predict_proba(X)

Predicts class probabilities for given texts.

Parameters:

Name Type Description Default
X Iterable[str]

Texts to predict probabilities for.

required

Returns:

Type Description
array of shape (n_texts, n_classes)

Class probabilities for each text.


set_output(transform=None)

Set output of the transform() function to be a dataframe instead of a matrix if you pass transform='pandas'. Otherwise it will disable pandas output.


transform(X)

Predicts class probabilities for given texts.

Parameters:

Name Type Description Default
X Iterable[str]

Texts to predict probabilities for.

required

Returns:

Type Description
array of shape (n_texts, n_classes)

Class probabilities for each text.
