Model training

Config File

The input config file for training a SAFE model is very similar to the GPT2 config file, with the addition of an optional `num_labels` attribute for training with descriptor regularization.
```json
{
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 10000,
  "embd_pdrop": 0.1,
  "eos_token_id": 1,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": "tanh",
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_hidden_size": 128,
  "summary_use_proj": true,
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 10000,
  "num_labels": 9
}
```
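For illustration, such a config can be loaded with the standard transformers API before building the model; this is a minimal sketch and the filename `config.json` is an assumption.

```python
from transformers import GPT2Config

# Load the JSON config shown above (assumed to be saved as "config.json")
config = GPT2Config.from_json_file("config.json")

print(config.n_layer, config.n_embd)  # 12 768
print(config.num_labels)              # 9 -> enables the descriptor-regression head
```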
SAFE Model

PropertyHead

Bases: Module

Compute a single vector summary of a sequence's hidden states.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
config | [`PretrainedConfig`] | The config used by the model. Relevant arguments in the config class of the model are the `summary_*` options (refer to the actual config class of your model for the default values it uses). | required |
Source code in safe/trainer/model.py
forward(hidden_states, cls_index=None)

Compute a single vector summary of a sequence's hidden states.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
hidden_states | FloatTensor | Hidden states of the sequence to summarize. | required |
cls_index | Optional[LongTensor] | Index of the classification token to summarize from when `summary_type="cls_index"`. | None |

Returns:

Type | Description |
---|---|
FloatTensor | The vector summary of the sequence hidden states. |
Source code in safe/trainer/model.py
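As a rough illustration only, the head can be exercised on its own. This sketch assumes the config above (with `summary_type: "cls_index"` and `num_labels: 9`) is saved as `config.json`; the tensor shapes are guesses based on the class description rather than taken from the source.

```python
import torch
from transformers import GPT2Config
from safe.trainer.model import PropertyHead

config = GPT2Config.from_json_file("config.json")
head = PropertyHead(config)

batch_size, seq_len = 4, 32
hidden_states = torch.randn(batch_size, seq_len, config.n_embd)
# With summary_type == "cls_index", summarize from the last token of each sequence
cls_index = torch.full((batch_size,), seq_len - 1, dtype=torch.long)

summary = head(hidden_states, cls_index=cls_index)
print(summary.shape)  # expected to end with num_labels (9) when descriptors are used
```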
SAFEDoubleHeadsModel

Bases: GPT2DoubleHeadsModel

The SAFE model is a dual-head GPT2 model with a language modeling head and an optional multi-task regression head.
Source code in safe/trainer/model.py
forward(input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, mc_token_ids=None, labels=None, mc_labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, inputs=None, encoder_hidden_states=None, **kwargs)
Parameters:

Name | Type | Description | Default |
---|---|---|---|
mc_token_ids | `torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, defaults to the index of the last token of the input | Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - 1]`. | None |
labels | `torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional* | Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set `labels = input_ids`. | None |
mc_labels | `torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional* | Labels for computing the supervised loss for regularization. | None |
inputs | Optional[Any] | List of inputs, put here because the trainer removes information not in the signature. | None |

Returns:

output (GPT2DoubleHeadsModelOutput): output of the model.
Source code in safe/trainer/model.py
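The sketch below, reusing the config file from above, illustrates a single forward pass through both heads; the random ids and descriptor targets are placeholders, and the output field names follow the standard `GPT2DoubleHeadsModelOutput` naming as an assumption.

```python
import torch
from transformers import GPT2Config
from safe.trainer.model import SAFEDoubleHeadsModel

config = GPT2Config.from_json_file("config.json")  # config shown in the Config File section
model = SAFEDoubleHeadsModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
labels = input_ids.clone()                     # LM labels are shifted inside the model
mc_labels = torch.randn(2, config.num_labels)  # one regression target per descriptor

outputs = model(input_ids=input_ids, labels=labels, mc_labels=mc_labels)
print(outputs.loss)     # language-modeling loss
print(outputs.mc_loss)  # descriptor (regression) loss, if the head is enabled
```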
Trainer

SAFETrainer

Bases: Trainer

Custom trainer for training a SAFE model.

This custom trainer changes the loss function to support the property head.
Source code in safe/trainer/trainer_utils.py
compute_loss(model, inputs, return_outputs=False)
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Source code in safe/trainer/trainer_utils.py
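A hedged sketch of wiring the pieces together: it assumes SAFETrainer accepts the usual `transformers.Trainer` arguments (any SAFE-specific extras keeping their defaults), and that `model`, `tokenizer`, and `train_dataset` have already been built as in the surrounding sections.

```python
from transformers import TrainingArguments
from safe.trainer.collator import SAFECollator
from safe.trainer.trainer_utils import SAFETrainer

# Assumes `model`, `tokenizer` and `train_dataset` already exist
# (see the model, collator and data utils sections).
collator = SAFECollator(tokenizer, max_length=128, include_descriptors=True)

args = TrainingArguments(
    output_dir="safe-model",
    per_device_train_batch_size=32,
    learning_rate=1e-4,
    max_steps=1_000,
)

trainer = SAFETrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=collator,
)
trainer.train()
```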
Data Collator

SAFECollator

Collate function for language modelling tasks.

Note

The collate function is based on the default DataCollatorForLanguageModeling in HuggingFace; see: https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/data/data_collator.py
Source code in safe/trainer/collator.py
__call__(samples)

Call the collate function.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
samples | List[Union[List[int], Any, Dict[str, Any]]] | list of examples | required |
Source code in safe/trainer/collator.py
__init__(tokenizer, pad_to_multiple_of=None, input_key='inputs', label_key='labels', property_key='descriptors', include_descriptors=False, max_length=None)

Default collator for huggingface transformers.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
tokenizer | Tokenizer | Huggingface tokenizer | required |
input_key | str | key to use for input ids | 'inputs' |
label_key | str | key to use for labels | 'labels' |
property_key | str | key to use for properties | 'descriptors' |
include_descriptors | bool | whether to include training on descriptors or not | False |
pad_to_multiple_of | Optional[int] | pad to multiple of this value | None |
Source code in safe/trainer/collator.py
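A short usage sketch: the tokenizer path is hypothetical (any tokenizer exposing the usual HuggingFace padding interface should do), and the sample dicts simply follow the default `input_key`/`property_key` names; depending on your pipeline, the `inputs` values may instead be pre-tokenized id lists.

```python
from transformers import AutoTokenizer
from safe.trainer.collator import SAFECollator

# Hypothetical tokenizer path; replace with your own SAFE-compatible tokenizer
tokenizer = AutoTokenizer.from_pretrained("path/to/safe-tokenizer")

collator = SAFECollator(
    tokenizer,
    max_length=64,
    include_descriptors=True,  # also collate the "descriptors" column for regularization
)

samples = [
    {"inputs": "c1ccccc1", "descriptors": [0.1] * 9},
    {"inputs": "C1CCOC1", "descriptors": [0.3] * 9},
]
batch = collator(samples)  # padded batch dict (exact keys depend on the collator internals)
```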
get_tokenizer() (cached)

Get the underlying tokenizer.
Source code in safe/trainer/collator.py
Data Utils

get_dataset(data_path, name=None, tokenizer=None, cache_dir=None, streaming=True, use_auth_token=False, tokenize_column='inputs', property_column='descriptors', max_length=None, num_shards=1024)

Get the datasets from the config file.
Source code in safe/trainer/data_utils.py
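A hedged example of building a streaming dataset with this helper; both paths are placeholders, and only parameters from the signature above are used.

```python
from transformers import AutoTokenizer
from safe.trainer.data_utils import get_dataset

tokenizer = AutoTokenizer.from_pretrained("path/to/safe-tokenizer")  # placeholder path

dataset = get_dataset(
    "path/to/safe-dataset",        # local folder or hub dataset name (placeholder)
    tokenizer=tokenizer,
    streaming=True,                # iterate over shards instead of loading everything in memory
    tokenize_column="inputs",
    property_column="descriptors",
    max_length=128,
)
```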
get_dataset_column_names(dataset)

Get the column names in a dataset.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
dataset | Union[Dataset, IterableDataset, Mapping] | dataset to get the column names from | required |
Source code in safe/trainer/data_utils.py
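For example, a small sketch using an in-memory `datasets.Dataset`:

```python
from datasets import Dataset
from safe.trainer.data_utils import get_dataset_column_names

ds = Dataset.from_dict({"inputs": ["c1ccccc1"], "descriptors": [[0.1] * 9]})
print(get_dataset_column_names(ds))  # expected: ['inputs', 'descriptors']
```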
take(n, iterable)

Return the first n items of the iterable as a list.
Source code in safe/trainer/data_utils.py
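For example:

```python
from safe.trainer.data_utils import take

print(take(3, range(10)))  # [0, 1, 2]
```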
tokenize_fn(row, tokenizer, tokenize_column='inputs', max_length=None, padding=False)

Perform the tokenization of a row.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
row | | row to tokenize | required |
tokenizer | | tokenizer to use | required |
tokenize_column | | column to tokenize | 'inputs' |
max_length | | maximum size of the tokenized sequence | None |
padding | | whether to pad the sequence | False |
Source code in safe/trainer/data_utils.py
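The map-based pattern below is an assumption about how this helper is meant to be combined with a `datasets` object, not something taken from the source; the tokenizer path is a placeholder.

```python
from functools import partial

from datasets import Dataset
from transformers import AutoTokenizer
from safe.trainer.data_utils import tokenize_fn

tokenizer = AutoTokenizer.from_pretrained("path/to/safe-tokenizer")  # placeholder path

raw = Dataset.from_dict({"inputs": ["c1ccccc1", "C1CCOC1"]})
tokenized = raw.map(
    partial(tokenize_fn, tokenizer=tokenizer, tokenize_column="inputs", max_length=64)
)
```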