
Model training

Config File

The input config file for training a SAFE model is very similar to a GPT2 config file, with the addition of an optional num_labels attribute for training with descriptor regularization.

{
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 10000,
  "embd_pdrop": 0.1,
  "eos_token_id": 1,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": "tanh",
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_hidden_size": 128,
  "summary_use_proj": true,
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 10000,
  "num_labels": 9
}
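
As a quick sketch of how such a config might be consumed (assuming the safe package exposes the model under safe.trainer.model, as the source paths below suggest), the JSON can be loaded with the standard GPT2Config and passed to SAFEDoubleHeadsModel:

from transformers import GPT2Config

from safe.trainer.model import SAFEDoubleHeadsModel

# Path to the JSON config shown above (illustrative).
config = GPT2Config.from_json_file("config.json")

# num_labels > 0 enables the descriptor (property) regression head.
model = SAFEDoubleHeadsModel(config)

If num_labels is omitted or set to 0, the property head never contributes a loss and the model trains as a plain language model.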

SAFE Model

PropertyHead

Bases: Module

Compute a single vector summary of a sequence's hidden states.

Parameters:

  • config (PretrainedConfig, required) -- The config used by the model. Relevant arguments in the config class of the model are (refer to the actual config class of your model for the default values it uses):

    - summary_type (str) -- The method to use to make this summary. Accepted values are:
      - "last" -- Take the last token hidden state (like XLNet)
      - "first" -- Take the first token hidden state (like BERT)
      - "mean" -- Take the mean of all token hidden states
      - "cls_index" -- Supply a Tensor of the classification token position (GPT/GPT-2)
    - summary_activation (Optional[str]) -- Set to "tanh" to add a tanh activation to the output; any other string or None adds no activation.
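
As an illustrative sketch (not part of the documented API), PropertyHead can be exercised on its own by carrying the summary_* attributes and num_labels on a GPT2Config:

import torch
from transformers import GPT2Config

from safe.trainer.model import PropertyHead

config = GPT2Config(
    n_embd=768,
    summary_type="cls_index",
    summary_activation="tanh",
    summary_hidden_size=128,
    num_labels=9,
)
head = PropertyHead(config)

hidden_states = torch.randn(4, 32, config.hidden_size)  # [batch, seq_len, hidden]
cls_index = torch.tensor([31, 10, 20, 5])               # summary token position per sequence

properties = head(hidden_states, cls_index=cls_index)   # shape: [4, 9]
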
Source code in safe/trainer/model.py
class PropertyHead(torch.nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)

            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string, or `None` to add no activation.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "cls_index")
        self.summary = torch.nn.Identity()
        last_hidden_size = config.hidden_size

        if getattr(config, "summary_hidden_size", None) and config.summary_hidden_size > 0:
            self.summary = nn.Linear(config.hidden_size, config.summary_hidden_size)
            last_hidden_size = config.summary_hidden_size

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = (
            get_activation(activation_string) if activation_string else nn.Identity()
        )

        self.out = torch.nn.Identity()
        if getattr(config, "num_labels", None) and config.num_labels > 0:
            num_labels = config.num_labels
            self.out = nn.Linear(last_hidden_size, num_labels)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        cls_index: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states: `torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`)
                The hidden states of the last layer.
            cls_index: `torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]`
                where ... are optional leading dimensions of `hidden_states`, *optional*
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            # if cls_index is None:
            #     cls_index = torch.full_like(
            #         hidden_states[..., :1, :],
            #         hidden_states.shape[-2] - 1,
            #         dtype=torch.long,
            #     )
            # else:
            #     cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
            #     cls_index = cls_index.expand(
            #         (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)
            #     )

            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            # output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
            batch_size = hidden_states.shape[0]
            output = hidden_states.squeeze()[torch.arange(batch_size), cls_index]
        else:
            raise NotImplementedError

        output = self.summary(output)
        output = self.activation(output)
        return self.out(output)

forward(hidden_states, cls_index=None)

Compute a single vector summary of a sequence's hidden states.

Parameters:

  • hidden_states (torch.FloatTensor of shape [batch_size, seq_len, hidden_size], required) -- The hidden states of the last layer.
  • cls_index (Optional[torch.LongTensor] of shape [batch_size] or [batch_size, ...], where ... are optional leading dimensions of hidden_states; *optional*, defaults to None) -- Used if summary_type == "cls_index" and takes the last token of the sequence as the classification token.

Returns:

  • torch.FloatTensor -- The summary of the sequence's hidden states.
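
For reference, the "cls_index" branch above amounts to picking one hidden state per sequence, as this small standalone snippet illustrates:

import torch

batch_size, seq_len, hidden_size = 2, 5, 8
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
cls_index = torch.tensor([4, 2])  # e.g. index of the last non-padding token

# One hidden state per sequence, as in the "cls_index" branch of forward().
summary = hidden_states[torch.arange(batch_size), cls_index]  # shape: [2, 8]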

Source code in safe/trainer/model.py
def forward(
    self,
    hidden_states: torch.FloatTensor,
    cls_index: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
    """
    Compute a single vector summary of a sequence hidden states.

    Args:
        hidden_states: `torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`)
            The hidden states of the last layer.
        cls_index: `torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]`
            where ... are optional leading dimensions of `hidden_states`, *optional*
            Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

    Returns:
        `torch.FloatTensor`: The summary of the sequence hidden states.
    """
    if self.summary_type == "last":
        output = hidden_states[:, -1]
    elif self.summary_type == "first":
        output = hidden_states[:, 0]
    elif self.summary_type == "mean":
        output = hidden_states.mean(dim=1)
    elif self.summary_type == "cls_index":
        # if cls_index is None:
        #     cls_index = torch.full_like(
        #         hidden_states[..., :1, :],
        #         hidden_states.shape[-2] - 1,
        #         dtype=torch.long,
        #     )
        # else:
        #     cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
        #     cls_index = cls_index.expand(
        #         (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)
        #     )

        # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
        # output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        batch_size = hidden_states.shape[0]
        output = hidden_states.squeeze()[torch.arange(batch_size), cls_index]
    else:
        raise NotImplementedError

    output = self.summary(output)
    output = self.activation(output)
    return self.out(output)

SAFEDoubleHeadsModel

Bases: GPT2DoubleHeadsModel

The SAFE model is a dual-head GPT2 model with a language modeling head and an optional multi-task regression head.
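
A minimal forward-pass sketch (sizes, shapes and values are illustrative only): with both labels and mc_labels supplied, the model returns a language-modeling loss and a descriptor regression loss.

import torch
from transformers import GPT2Config

from safe.trainer.model import SAFEDoubleHeadsModel

config = GPT2Config(vocab_size=10000, n_layer=2, n_head=2, n_embd=64, num_labels=9, summary_hidden_size=32)
model = SAFEDoubleHeadsModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
mc_labels = torch.randn(2, 9)        # one descriptor vector per sequence
mc_token_ids = torch.full((2,), 15)  # summarize from the last token of each sequence

outputs = model(
    input_ids=input_ids,
    labels=input_ids,                # labels are shifted inside the model
    mc_labels=mc_labels,
    mc_token_ids=mc_token_ids,
    return_dict=True,
)
outputs.loss     # language-modeling loss
outputs.mc_loss  # descriptor regression (MSE) loss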

Source code in safe/trainer/model.py
class SAFEDoubleHeadsModel(GPT2DoubleHeadsModel):
    """The safe model is a dual head GPT2 model with a language modeling head and an optional multi-task regression head"""

    def __init__(self, config):
        self.num_labels = getattr(config, "num_labels", None)
        super().__init__(config)
        self.config.num_labels = self.num_labels
        del self.multiple_choice_head
        self.multiple_choice_head = PropertyHead(config)

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        mc_token_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mc_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        inputs: Optional[Any] = None,  # do not remove because of trainer
        encoder_hidden_states: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
        r"""

        Args:
            mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
                Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
                1]`.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
                `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
                `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
            mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*):
                Labels for computing the supervised loss for regularization.
            inputs: List of inputs, put here because the trainer removes information not in signature
        Returns:
            output (GPT2DoubleHeadsModelOutput): output of the model
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if mc_token_ids is None and self.config.pad_token_id is not None and input_ids is not None:
            mc_token_ids = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(
                lm_logits.device
            )

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        mc_loss = None
        mc_logits = None
        if mc_labels is not None and getattr(self.config, "num_labels", 0) > 0:
            mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
            mc_labels = mc_labels.to(mc_logits.device)
            loss_fct = MSELoss()
            mc_loss = loss_fct(
                mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1, mc_logits.size(-1))
            )

        lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits, mc_logits) + transformer_outputs[1:]
            return (
                lm_loss,
                mc_loss,
            ) + output

        return GPT2DoubleHeadsModelOutput(
            loss=lm_loss,
            mc_loss=mc_loss,
            logits=lm_logits,
            mc_logits=mc_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

forward(input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, labels=None, mc_labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, inputs=None, encoder_hidden_states=None, **kwargs)

Parameters:

  • mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input) -- Index of the classification token in each input sequence. Selected in the range [0, input_ids.size(-1) - 1].
  • labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*, defaults to None) -- Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set labels = input_ids. Indices are selected in [-100, 0, ..., config.vocab_size - 1]. All labels set to -100 are ignored (masked); the loss is only computed for labels in [0, ..., config.vocab_size - 1].
  • mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*, defaults to None) -- Labels for computing the supervised loss used for regularization.
  • inputs (Optional[Any], defaults to None) -- List of inputs, kept in the signature because the trainer removes arguments that are not part of it.

Returns:

  • output (GPT2DoubleHeadsModelOutput) -- The output of the model.
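
When mc_token_ids is not given and the config defines a pad_token_id, the model falls back to the index of the last non-padding token of each sequence; the snippet below shows that default computation in isolation:

import torch

pad_token_id = 0
input_ids = torch.tensor([
    [12, 45, 8, 0, 0],   # 3 real tokens -> classification index 2
    [7, 99, 23, 5, 0],   # 4 real tokens -> classification index 3
])

mc_token_ids = torch.ne(input_ids, pad_token_id).sum(-1) - 1
# tensor([2, 3])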

Source code in safe/trainer/model.py
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    mc_token_ids: Optional[torch.LongTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    mc_labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    inputs: Optional[Any] = None,  # do not remove because of trainer
    encoder_hidden_states: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
    r"""

    Args:
        mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
            Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
            1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
            `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
        mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*):
            Labels for computing the supervised loss for regularization.
        inputs: List of inputs, put here because the trainer removes information not in signature
    Returns:
        output (GPT2DoubleHeadsModelOutput): output of the model
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    transformer_outputs = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        encoder_hidden_states=encoder_hidden_states,
    )

    hidden_states = transformer_outputs[0]
    lm_logits = self.lm_head(hidden_states)

    if mc_token_ids is None and self.config.pad_token_id is not None and input_ids is not None:
        mc_token_ids = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(
            lm_logits.device
        )

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.transformer.first_device)
        hidden_states = hidden_states.to(self.lm_head.weight.device)

    mc_loss = None
    mc_logits = None
    if mc_labels is not None and getattr(self.config, "num_labels", 0) > 0:
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
        mc_labels = mc_labels.to(mc_logits.device)
        loss_fct = MSELoss()
        mc_loss = loss_fct(
            mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1, mc_logits.size(-1))
        )

    lm_loss = None
    if labels is not None:
        labels = labels.to(lm_logits.device)
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    if not return_dict:
        output = (lm_logits, mc_logits) + transformer_outputs[1:]
        return (
            lm_loss,
            mc_loss,
        ) + output

    return GPT2DoubleHeadsModelOutput(
        loss=lm_loss,
        mc_loss=mc_loss,
        logits=lm_logits,
        mc_logits=mc_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )

Trainer

SAFETrainer

Bases: Trainer

Custom trainer for training a SAFE model.

This custom trainer changes the loss function to support the property head.
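
A minimal sketch of how SAFETrainer could be wired up (the model, dataset and collator placeholders stand for the objects documented elsewhere on this page; argument values are illustrative):

from transformers import TrainingArguments

from safe.trainer.trainer_utils import SAFETrainer

# `model`, `train_dataset` and `data_collator` are assumed to be built as shown in the
# SAFEDoubleHeadsModel, Data Collator and Data Utils sections of this page.
training_args = TrainingArguments(output_dir="safe-model", per_device_train_batch_size=32)

trainer = SAFETrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
    prop_loss_coeff=1e-3,  # total loss = lm_loss + prop_loss_coeff * mc_loss
)
trainer.train()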

Source code in safe/trainer/trainer_utils.py
class SAFETrainer(Trainer):
    """
    Custom trainer for training SAFE model.

    This custom trainer changes the loss function to support the property head

    """

    def __init__(self, *args, prop_loss_coeff: float = 1e-3, **kwargs):
        super().__init__(*args, **kwargs)
        self.prop_loss_coeff = prop_loss_coeff

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.
        """
        labels = (
            inputs.pop("labels") if self.label_smoother is not None and "labels" in inputs else None
        )
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = self.accelerator.unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        mc_loss = outputs.get("mc_loss", None) if isinstance(outputs, dict) else outputs[1]
        if mc_loss is not None:
            loss = loss + self.prop_loss_coeff * mc_loss
        return (loss, outputs) if return_outputs else loss

compute_loss(model, inputs, return_outputs=False)

How the loss is computed by Trainer. By default, all models return the loss in the first element.

Source code in safe/trainer/trainer_utils.py
def compute_loss(self, model, inputs, return_outputs=False):
    """
    How the loss is computed by Trainer. By default, all models return the loss in the first element.
    """
    labels = (
        inputs.pop("labels") if self.label_smoother is not None and "labels" in inputs else None
    )
    outputs = model(**inputs)
    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    if labels is not None:
        unwrapped_model = self.accelerator.unwrap_model(model)
        if _is_peft_model(unwrapped_model):
            model_name = unwrapped_model.base_model.model._get_name()
        else:
            model_name = unwrapped_model._get_name()
        if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)
    else:
        if isinstance(outputs, dict) and "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

    mc_loss = outputs.get("mc_loss", None) if isinstance(outputs, dict) else outputs[1]
    if mc_loss is not None:
        loss = loss + self.prop_loss_coeff * mc_loss
    return (loss, outputs) if return_outputs else loss

Data Collator

SAFECollator

Collate function for language modelling tasks

Note

The collate function is based on the default DataCollatorForLanguageModeling in Hugging Face Transformers; see https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/data/data_collator.py
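
A short usage sketch (the tokenizer placeholder stands for a SAFETokenizer or a Hugging Face fast tokenizer with a pad token set; the input strings and descriptor values are illustrative):

from safe.trainer.collator import SAFECollator

# `tokenizer` is assumed to be a SAFETokenizer or a Hugging Face fast tokenizer
# with a pad token defined.
collator = SAFECollator(
    tokenizer=tokenizer,
    max_length=128,
    include_descriptors=True,  # forward the "descriptors" column as mc_labels
)

samples = [
    {"inputs": "c1ccccc1", "descriptors": [0.1, 0.2, 0.3]},
    {"inputs": "CCO", "descriptors": [0.4, 0.5, 0.6]},
]
batch = collator(samples)
# batch holds input_ids, attention_mask, labels (pad positions set to -100) and mc_labels.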

Source code in safe/trainer/collator.py
class SAFECollator:
    """Collate function for language modelling tasks


    !!! note
        The collate function is based on the default DataCollatorForLanguageModeling in huggingface
        see: https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/data/data_collator.py
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        pad_to_multiple_of: Optional[int] = None,
        input_key: str = "inputs",
        label_key: str = "labels",
        property_key: str = "descriptors",
        include_descriptors: bool = False,
        max_length: Optional[int] = None,
    ):
        """
        Default collator for huggingface transformers in izanagi.

        Args:
            tokenizer: Huggingface tokenizer
            input_key: key to use for input ids
            label_key: key to use for labels
            property_key: key to use for properties
            include_descriptors: whether to include training on descriptors or not
            pad_to_multiple_of: pad to multiple of this value
        """

        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of
        self.input_key = input_key
        self.label_key = label_key
        self.property_key = property_key
        self.include_descriptors = include_descriptors
        self.max_length = max_length

    @functools.lru_cache()
    def get_tokenizer(self):
        """Get underlying tokenizer"""
        if isinstance(self.tokenizer, SAFETokenizer):
            return self.tokenizer.get_pretrained()
        return self.tokenizer

    def __call__(self, samples: List[Union[List[int], Any, Dict[str, Any]]]):
        """
        Call collate function

        Args:
            samples: list of examples
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        tokenizer = self.get_tokenizer()

        # examples = samples
        examples = copy.deepcopy(samples)
        inputs = [example.pop(self.input_key, None) for example in examples]
        mc_labels = (
            torch.tensor([example.pop(self.property_key, None) for example in examples]).float()
            if self.property_key in examples[0]
            else None
        )

        if "input_ids" not in examples[0] and inputs is not None:
            batch = tokenizer(
                inputs,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
        else:
            batch = tokenizer.pad(
                examples,
                return_tensors="pt",
                padding=True,
                pad_to_multiple_of=self.pad_to_multiple_of,
                max_length=self.max_length,
            )

        # If special token mask has been preprocessed, pop it from the dict.
        batch.pop("special_tokens_mask", None)
        labels = batch.get(self.label_key, batch["input_ids"].clone())
        if tokenizer.pad_token_id is not None:
            labels[labels == tokenizer.pad_token_id] = -100
        batch[self.label_key] = labels

        if mc_labels is not None and self.include_descriptors:
            batch.update(
                {
                    "mc_labels": mc_labels,
                    # "input_text": inputs,
                }
            )
        return batch

__call__(samples)

Call collate function

Parameters:

  • samples (List[Union[List[int], Any, Dict[str, Any]]], required) -- List of examples.
Source code in safe/trainer/collator.py
def __call__(self, samples: List[Union[List[int], Any, Dict[str, Any]]]):
    """
    Call collate function

    Args:
        samples: list of examples
    """
    # Handle dict or lists with proper padding and conversion to tensor.
    tokenizer = self.get_tokenizer()

    # examples = samples
    examples = copy.deepcopy(samples)
    inputs = [example.pop(self.input_key, None) for example in examples]
    mc_labels = (
        torch.tensor([example.pop(self.property_key, None) for example in examples]).float()
        if self.property_key in examples[0]
        else None
    )

    if "input_ids" not in examples[0] and inputs is not None:
        batch = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
    else:
        batch = tokenizer.pad(
            examples,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=self.pad_to_multiple_of,
            max_length=self.max_length,
        )

    # If special token mask has been preprocessed, pop it from the dict.
    batch.pop("special_tokens_mask", None)
    labels = batch.get(self.label_key, batch["input_ids"].clone())
    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100
    batch[self.label_key] = labels

    if mc_labels is not None and self.include_descriptors:
        batch.update(
            {
                "mc_labels": mc_labels,
                # "input_text": inputs,
            }
        )
    return batch

__init__(tokenizer, pad_to_multiple_of=None, input_key='inputs', label_key='labels', property_key='descriptors', include_descriptors=False, max_length=None)

Default collator for huggingface transformers in izanagi.

Parameters:

  • tokenizer (Tokenizer, required) -- Hugging Face tokenizer.
  • input_key (str, defaults to 'inputs') -- Key to use for input ids.
  • label_key (str, defaults to 'labels') -- Key to use for labels.
  • property_key (str, defaults to 'descriptors') -- Key to use for properties.
  • include_descriptors (bool, defaults to False) -- Whether to include training on descriptors or not.
  • pad_to_multiple_of (Optional[int], defaults to None) -- Pad to a multiple of this value.
Source code in safe/trainer/collator.py
def __init__(
    self,
    tokenizer: Tokenizer,
    pad_to_multiple_of: Optional[int] = None,
    input_key: str = "inputs",
    label_key: str = "labels",
    property_key: str = "descriptors",
    include_descriptors: bool = False,
    max_length: Optional[int] = None,
):
    """
    Default collator for huggingface transformers in izanagi.

    Args:
        tokenizer: Huggingface tokenizer
        input_key: key to use for input ids
        label_key: key to use for labels
        property_key: key to use for properties
        include_descriptors: whether to include training on descriptors or not
        pad_to_multiple_of: pad to multiple of this value
    """

    self.tokenizer = tokenizer
    self.pad_to_multiple_of = pad_to_multiple_of
    self.input_key = input_key
    self.label_key = label_key
    self.property_key = property_key
    self.include_descriptors = include_descriptors
    self.max_length = max_length

get_tokenizer() cached

Get underlying tokenizer

Source code in safe/trainer/collator.py
@functools.lru_cache()
def get_tokenizer(self):
    """Get underlying tokenizer"""
    if isinstance(self.tokenizer, SAFETokenizer):
        return self.tokenizer.get_pretrained()
    return self.tokenizer

Data Utils

get_dataset(data_path, name=None, tokenizer=None, cache_dir=None, streaming=True, use_auth_token=False, tokenize_column='inputs', property_column='descriptors', max_length=None, num_shards=1024)

Get the datasets from the config file
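
For example (the path, column names and tokenizer are placeholders; when tokenizer is None the raw dataset is returned without tokenization):

from safe.trainer.data_utils import get_dataset

# `tokenizer` is assumed to be a SAFETokenizer or a Hugging Face tokenizer.
dataset = get_dataset(
    "path/to/dataset",            # local dataset saved to disk, or a hub dataset name
    tokenizer=tokenizer,
    streaming=True,
    tokenize_column="inputs",
    property_column="descriptors",
    max_length=128,
)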

Source code in safe/trainer/data_utils.py
def get_dataset(
    data_path,
    name: Optional[str] = None,
    tokenizer: Optional[Callable] = None,
    cache_dir: Optional[str] = None,
    streaming: bool = True,
    use_auth_token: bool = False,
    tokenize_column: Optional[str] = "inputs",
    property_column: Optional[str] = "descriptors",
    max_length: Optional[int] = None,
    num_shards=1024,
):
    """Get the datasets from the config file"""
    raw_datasets = {}
    if data_path is not None:
        data_path = upath.UPath(str(data_path))

        if data_path.exists():
            # then we need to load from disk
            data_path = str(data_path)
            # for some reason, the datasets package is not able to load the dataset
            # because the splits were not originally proposed
            raw_datasets = datasets.load_from_disk(data_path)

            if streaming:
                if isinstance(raw_datasets, datasets.DatasetDict):
                    previous_num_examples = {k: len(dt) for k, dt in raw_datasets.items()}
                    raw_datasets = datasets.IterableDatasetDict(
                        {
                            k: dt.to_iterable_dataset(num_shards=num_shards)
                            for k, dt in raw_datasets.items()
                        }
                    )
                    for k, dt in raw_datasets.items():
                        if previous_num_examples[k] is not None:
                            setattr(dt, "num_examples", previous_num_examples[k])
                else:
                    num_examples = len(raw_datasets)
                    raw_datasets = raw_datasets.to_iterable_dataset(num_shards=num_shards)
                    setattr(raw_datasets, "num_examples", num_examples)

        else:
            data_path = str(data_path)
            raw_datasets = datasets.load_dataset(
                data_path,
                name=name,
                cache_dir=cache_dir,
                use_auth_token=True if use_auth_token else None,
                streaming=streaming,
            )
    # that means we need to return a tokenized version of the dataset

    if property_column not in ["mc_labels", None]:
        raw_datasets = raw_datasets.rename_column(property_column, "mc_labels")

    columns_to_remove = None
    if tokenize_column is not None:
        columns_to_remove = [
            x
            for x in (get_dataset_column_names(raw_datasets) or [])
            if x not in [tokenize_column, "mc_labels"] and "label" not in x
        ] or None

    if tokenizer is None:
        if columns_to_remove is not None:
            raw_datasets = raw_datasets.remove_columns(columns_to_remove)
        return raw_datasets

    return raw_datasets.map(
        partial(
            tokenize_fn,
            tokenizer=tokenizer,
            tokenize_column=tokenize_column,
            max_length=max_length,
        ),
        batched=True,
        remove_columns=columns_to_remove,
    )

get_dataset_column_names(dataset)

Get the column names in a dataset

Parameters:

  • dataset (Union[Dataset, IterableDataset, Mapping], required) -- Dataset to get the column names from.
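
For instance, with a DatasetDict the column names of the first split are returned:

from datasets import Dataset, DatasetDict

from safe.trainer.data_utils import get_dataset_column_names

ds = DatasetDict({"train": Dataset.from_dict({"inputs": ["CCO"], "descriptors": [[0.1]]})})
get_dataset_column_names(ds)  # ['inputs', 'descriptors']
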
Source code in safe/trainer/data_utils.py
def get_dataset_column_names(dataset: Union[datasets.Dataset, datasets.IterableDataset, Mapping]):
    """Get the column names in a dataset

    Args:
        dataset: dataset to get the column names from

    """
    if isinstance(dataset, (datasets.IterableDatasetDict, Mapping)):
        column_names = {split: dataset[split].column_names for split in dataset}
    else:
        column_names = dataset.column_names
    if isinstance(column_names, dict):
        column_names = list(column_names.values())[0]
    return column_names

take(n, iterable)

Return first n items of the iterable as a list

Source code in safe/trainer/data_utils.py
def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(itertools.islice(iterable, n))

tokenize_fn(row, tokenizer, tokenize_column='inputs', max_length=None, padding=False)

Perform the tokenization of a row.

Parameters:

  • row -- Row to tokenize.
  • tokenizer -- Tokenizer to use.
  • tokenize_column -- Column to tokenize.
  • max_length -- Maximum size of the tokenized sequence.
  • padding -- Whether to pad the sequence.
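
A small sketch of how a batched dataset row flows through tokenize_fn (the tokenizer placeholder stands for a SAFETokenizer or a Hugging Face fast tokenizer; with datasets.map(batched=True) the row holds lists of values):

from safe.trainer.data_utils import tokenize_fn

# `tokenizer` is assumed to be a SAFETokenizer or a Hugging Face fast tokenizer.
row = {"inputs": ["c1ccccc1", "CCO"]}
encoded = tokenize_fn(row, tokenizer=tokenizer, tokenize_column="inputs", max_length=64)
# encoded contains "input_ids" (and "attention_mask") lists that datasets.map merges
# back into the dataset.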

Source code in safe/trainer/data_utils.py
def tokenize_fn(
    row: Dict[str, Any],
    tokenizer: Callable,
    tokenize_column: str = "inputs",
    max_length: Optional[int] = None,
    padding: bool = False,
):
    """Perform the tokenization of a row
    Args:
        row: row to tokenize
        tokenizer: tokenizer to use
        tokenize_column: column to tokenize
        max_length: maximum size of the tokenized sequence
        padding: whether to pad the sequence
    """
    # there's probably a way to do this with the tokenizer settings
    # but again, gotta move fast

    fast_tokenizer = (
        tokenizer.get_pretrained() if isinstance(tokenizer, SAFETokenizer) else tokenizer
    )

    return fast_tokenizer(
        row[tokenize_column],
        truncation=(max_length is not None),
        max_length=max_length,
        padding=padding,
        return_tensors=None,
    )