Skip to content

SAFE

SAFE Encoder-Decoder

SAFEConverter

Molecule line notation conversion from SMILES to SAFE

A SAFE representation is a string-based representation of a molecule decomposition into fragment components, separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by itself, unless explicitly corrected to add missing hydrogens.

Slicing algorithms

By default SAFE strings are generated using BRICS; however, the following alternatives are supported:

Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms corresponding to the bonds to break.

Source code in safe/converter.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
class SAFEConverter:
    """Molecule line notation conversion from SMILES to SAFE

    A SAFE representation is a string based representation of a molecule decomposition into fragment components,
    separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by itself,
    unless explicitly corrected to add missing hydrogens.

    !!! note "Slicing algorithms"

        By default SAFE strings are generated using `BRICS`, however, the following alternatives are supported:

        * [Hussain-Rea (`hr`)](https://pubs.acs.org/doi/10.1021/ci900450m)
        * [RECAP (`recap`)](https://pubmed.ncbi.nlm.nih.gov/9611787/)
        * [RDKit's MMPA (`mmpa`)](https://www.rdkit.org/docs/source/rdkit.Chem.rdMMPA.html)
        * Any possible attachment points (`attach`)
        * Rotatable bonds (`rotatable`)

        Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms
        corresponding to the bonds to break.

    """

    SUPPORTED_SLICERS = ["hr", "rotatable", "recap", "mmpa", "attach", "brics"]
    __SLICE_SMARTS = {
        "hr": ["[*]!@-[*]"],  # any non ring single bond
        "recap": [
            "[$([C;!$(C([#7])[#7])](=!@[O]))]!@[$([#7;+0;!D1])]",
            "[$(C=!@O)]!@[$([O;+0])]",
            "[$([N;!D1;+0;!$(N-C=[#7,#8,#15,#16])](-!@[*]))]-!@[$([*])]",
            "[$(C(=!@O)([#7;+0;D2,D3])!@[#7;+0;D2,D3])]!@[$([#7;+0;D2,D3])]",
            "[$([O;+0](-!@[#6!$(C=O)])-!@[#6!$(C=O)])]-!@[$([#6!$(C=O)])]",
            "C=!@C",
            "[N;+1;D4]!@[#6]",
            "[$([n;+0])]-!@C",
            "[$([O]=[C]-@[N;+0])]-!@[$([C])]",
            "c-!@c",
            "[$([#7;+0;D2,D3])]-!@[$([S](=[O])=[O])]",
        ],
        "mmpa": ["[#6+0;!$(*=,#[!#6])]!@!=!#[*]"],  # classical mmpa slicing smarts
        "attach": ["[*]!@[*]"],  # any potential attachment point, including hydrogens when explicit
        "rotatable": ["[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"],
    }

    def __init__(
        self,
        slicer: Optional[Union[str, List[str], Callable]] = "brics",
        require_hs: Optional[bool] = None,
        use_original_opener_for_attach: bool = True,
        ignore_stereo: bool = False,
    ):
        """Constructor for the SAFE converter

        Args:
            slicer: slicer algorithm to use for encoding.
                Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS, matched
                case-insensitively), a SMARTS string or list of SMARTS strings,
                or a custom callable that returns the bond ids that can be sliced.
            require_hs: whether the slicing algorithm require the molecule to have hydrogen explicitly added.
                `attach` slicer requires adding hydrogens.
            use_original_opener_for_attach: whether to use the original branch opener digit when adding back
                mapping number to attachment points, or use simple enumeration.
            ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined.

        Raises:
            ValueError: if one of the SMARTS patterns of the slicer cannot be parsed.
        """
        # Normalise the name once so that every later comparison is case-insensitive
        # (e.g. "BRICS" selects the BRICS slicer and "ATTACH" still turns on require_hs).
        slicer_key = slicer.lower() if isinstance(slicer, str) else None
        self.slicer = slicer
        if slicer_key in self.SUPPORTED_SLICERS:
            # "brics" has no SMARTS entry: it is kept as the sentinel string "brics"
            self.slicer = self.__SLICE_SMARTS.get(slicer_key, slicer_key)
        if self.slicer != "brics" and isinstance(self.slicer, str):
            # any other string is treated as a single user-provided SMARTS pattern
            self.slicer = [self.slicer]
        if isinstance(self.slicer, (list, tuple)):
            self.slicer = [dm.from_smarts(x) for x in self.slicer]
            if any(x is None for x in self.slicer):
                raise ValueError(f"Slicer: {slicer} cannot be valid")
        self.require_hs = require_hs or (slicer_key == "attach")
        self.use_original_opener_for_attach = use_original_opener_for_attach
        self.ignore_stereo = ignore_stereo

    @staticmethod
    def randomize(mol: dm.Mol, rng: Optional[Union[int, np.random.Generator]] = None):
        """Randomize the position of the atoms in a mol.

        Args:
            mol: molecules to randomize
            rng: optional seed (int) or random generator to use. When None, a freshly
                seeded generator is created.

        Returns:
            The input molecule when it has no atoms, otherwise the molecule with its atoms renumbered
            according to a random permutation.
        """
        # `default_rng(None)` builds an unseeded generator, so the default argument no longer
        # fails on `rng.permutation`. Any other object is assumed to expose `.permutation`.
        if rng is None or isinstance(rng, int):
            rng = np.random.default_rng(rng)
        if mol.GetNumAtoms() == 0:
            return mol
        atom_indices = list(range(mol.GetNumAtoms()))
        atom_indices = rng.permutation(atom_indices).tolist()
        return Chem.RenumberAtoms(mol, atom_indices)

    @classmethod
    def _find_branch_number(cls, inp: str):
        """Find the branch number and ring closure in the SMILES representation using regexp

        Args:
            inp: input smiles

        Returns:
            List of the ring-closure numbers (as int) in their order of appearance,
            bracketed atoms being ignored.
        """
        # drop every bracketed atom first so that isotopes, charges and atom maps are not counted
        inp = re.sub(r"\[.*?\]", "", inp)  # noqa
        matching_groups = re.findall(r"((?<=%)\d{2})|((?<!%)\d+)(?![^\[]*\])", inp)
        # first group: a two-digit closure introduced by '%' (e.g. "%12" -> 12)
        # second group: a run of plain digits, where each digit is its own closure (e.g. "12" -> 1, 2)
        # SMILES does not support triple digits
        branch_numbers = []
        for m in matching_groups:
            if m[0] == "":
                branch_numbers.extend(int(mm) for mm in m[1])
            elif m[1] == "":
                # the lookbehind already excludes the '%' sign from the captured text
                branch_numbers.append(int(m[0]))
        return branch_numbers

    def _ensure_valid(self, inp: str):
        """Ensure that the input SAFE string is valid by fixing the missing attachment points

        Args:
            inp: input SAFE string

        Returns:
            The input string followed by one dot-separated `[*:k]<digit>` token for every
            ring-closure number that appears an odd number of times.
        """
        missing_tokens = [inp]
        branch_numbers = self._find_branch_number(inp)
        # any branch number that is not pairwise (odd number of occurrences) should receive
        # a dummy atom to complete the attachment point
        branch_numbers = Counter(branch_numbers)
        for i, (bnum, bcount) in enumerate(branch_numbers.items()):
            if bcount % 2 != 0:
                bnum_str = str(bnum) if bnum < 10 else f"%{bnum}"
                _tk = f"[*:{i+1}]{bnum_str}"
                if self.use_original_opener_for_attach:
                    bnum_digit = bnum_str.strip("%")  # strip out the % sign
                    _tk = f"[*:{bnum_digit}]{bnum_str}"
                missing_tokens.append(_tk)
        return ".".join(missing_tokens)

    def decoder(
        self,
        inp: str,
        as_mol: bool = False,
        canonical: bool = False,
        fix: bool = True,
        remove_dummies: bool = True,
        remove_added_hs: bool = True,
    ):
        """Convert input SAFE representation to smiles

        Args:
            inp: input SAFE representation to decode as a valid molecule or smiles
            as_mol: whether to return a molecule object or a smiles string
            canonical: whether to return a canonical (standardized) molecule or smiles
            fix: whether to fix the SAFE representation to take into account non-connected attachment points
            remove_dummies: whether to remove dummy atoms from the SAFE representation. This step is
                best-effort: any exception raised while removing them is suppressed and the molecule
                is kept as it was.
            remove_added_hs: whether to remove all the added hydrogen atoms after applying dummy removal for recovery

        Returns:
            A molecule when `as_mol` is True, a smiles string otherwise.
        """

        if fix:
            inp = self._ensure_valid(inp)
        mol = dm.to_mol(inp)
        if remove_dummies:
            with suppress(Exception):
                # terminal dummy atoms attached through a bond that is neither single nor aromatic
                # are first replaced by a carbon, then the remaining dummies are removed
                du = dm.from_smarts("[$([#0]!-!:*);$([#0;D1])]")
                out = Chem.ReplaceSubstructs(mol, du, dm.to_mol("C"), True)[0]
                mol = dm.remove_dummies(out)
        if as_mol:
            if remove_added_hs:
                mol = dm.remove_hs(mol, update_explicit_count=True)
            if canonical:
                mol = dm.standardize_mol(mol)
                mol = dm.canonical_tautomer(mol)
            return mol
        out = dm.to_smiles(mol, canonical=canonical, explicit_hs=(not remove_added_hs))
        if canonical:
            out = dm.standardize_smiles(out)
        return out

    def _fragment(self, mol: dm.Mol, allow_empty: bool = False):
        """
        Find the bonds to cut in the input molecule, given the slicing algorithm

        Args:
            mol: input molecule to split
            allow_empty: whether to allow the slicing algorithm to return empty bonds

        Returns:
            List of atom-index pairs, one per bond that can be cut (possibly empty when `allow_empty`).

        Raises:
            SAFEFragmentationError: if the slicing algorithm return empty bonds
        """

        if self.slicer is None:
            matching_bonds = []

        elif callable(self.slicer):
            matching_bonds = self.slicer(mol)
            matching_bonds = list(matching_bonds)

        elif self.slicer == "brics":
            matching_bonds = BRICS.FindBRICSBonds(mol)
            # each BRICS match is ((atom_a, atom_b), labels): only the atom pair is needed
            matching_bonds = [brics_match[0] for brics_match in matching_bonds]

        else:
            # SMARTS based slicers: collect the unique (sorted) atom pairs over every pattern
            matches = set()
            for smarts in self.slicer:
                matches |= {
                    tuple(sorted(match)) for match in mol.GetSubstructMatches(smarts, uniquify=True)
                }
            matching_bonds = list(matches)

        # parentheses make the existing precedence explicit: a None result is always an error,
        # an empty result only when empty slicing is not allowed
        if matching_bonds is None or (len(matching_bonds) == 0 and not allow_empty):
            raise SAFEFragmentationError(
                "Slicing algorithms did not return any bonds that can be cut !"
            )
        return matching_bonds or []

    def encoder(
        self,
        inp: Union[str, dm.Mol],
        canonical: bool = True,
        randomize: Optional[bool] = False,
        seed: Optional[int] = None,
        constraints: Optional[List[dm.Mol]] = None,
        allow_empty: bool = False,
        rdkit_safe: bool = True,
    ):
        """Convert input smiles to SAFE representation

        Args:
            inp: input smiles
            canonical: whether to return canonical smiles string. Defaults to True
            randomize: whether to randomize the safe string encoding. The fragment order is shuffled
                whenever this is set; the atom order of the input is additionally randomized
                only when `canonical` is False.
            seed: optional seed to use when allowing randomization of the SAFE encoding.
                Randomization happens at two steps:
                1. at the original smiles representation by randomizing the atoms.
                2. at the SAFE conversion by randomizing fragment orders
            constraints: List of molecules or pattern to preserve during the SAFE construction. Any bond slicing would
                happen outside of a substructure matching one of the patterns.
            allow_empty: whether to allow the slicing algorithm to return empty bonds
            rdkit_safe: whether to apply rdkit-safe digit standardization to the output SAFE string.

        Returns:
            The SAFE string: dot-separated fragments joined back through ring-closure digits.

        Raises:
            SAFEFragmentationError: if no bond can be cut and `allow_empty` is False.
        """
        rng = None
        if randomize:
            rng = np.random.default_rng(seed)
            if not canonical:
                inp = dm.to_mol(inp, remove_hs=False)
                inp = self.randomize(inp, rng)

        if isinstance(inp, dm.Mol):
            inp = dm.to_smiles(inp, canonical=canonical, randomize=False, ordered=False)

        # EN: we first normalize the attachment if the molecule is a query:
        # inp = dm.reactions.convert_attach_to_isotope(inp, as_smiles=True)

        # TODO(maclandrol): RDKit supports some extended form of ring closure, up to 5 digits
        # https://www.rdkit.org/docs/RDKit_Book.html#ring-closures and I should try to include them
        branch_numbers = self._find_branch_number(inp)

        mol = dm.to_mol(inp, remove_hs=False)
        potential_stereos = Chem.FindPotentialStereo(mol)
        has_stereo_bonds = any(x.type == Chem.StereoType.Bond_Double for x in potential_stereos)
        if self.ignore_stereo:
            mol = dm.remove_stereochemistry(mol)

        # dummy atoms already present in the input get their atom map cleared and a unique
        # isotope label, so the labels given to the cut bonds below start after them
        bond_map_id = 1
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() == 0:
                atom.SetAtomMapNum(0)
                atom.SetIsotope(bond_map_id)
                bond_map_id += 1

        if self.require_hs:
            mol = dm.add_hs(mol)
        matching_bonds = self._fragment(mol, allow_empty=allow_empty)
        substructed_ignored = []
        if constraints is not None:
            substructed_ignored = list(
                itertools.chain(
                    *[
                        mol.GetSubstructMatches(constraint, uniquify=True)
                        for constraint in constraints
                    ]
                )
            )

        bonds = []
        for i_a, i_b in matching_bonds:
            # if both atoms of the bond are found in a disallowed substructure, we cannot consider them
            # on the other end, a bond between two substructure to preserved independently is perfectly fine
            if any((i_a in ignore_x and i_b in ignore_x) for ignore_x in substructed_ignored):
                continue
            obond = mol.GetBondBetweenAtoms(i_a, i_b)
            bonds.append(obond.GetIdx())

        if len(bonds) > 0:
            # both ends of a cut bond receive the same label so they can be paired again later
            mol = Chem.FragmentOnBonds(
                mol,
                bonds,
                dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))],
            )
        # here we need to be clever and disable rooted atom as the atom with mapping

        frags = list(Chem.GetMolFrags(mol, asMols=True))
        if randomize:
            frags = rng.permutation(frags).tolist()
        elif canonical:
            # largest fragment first
            frags = sorted(
                frags,
                key=lambda x: x.GetNumAtoms(),
                reverse=True,
            )

        frags_str = []
        for frag in frags:
            # root each fragment smiles at its first non-dummy atom
            non_map_atom_idxs = [
                atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0
            ]
            frags_str.append(
                Chem.MolToSmiles(
                    frag,
                    isomericSmiles=True,
                    canonical=True,  # needs to always be true
                    rootedAtAtom=non_map_atom_idxs[0],
                )
            )

        scaffold_str = ".".join(frags_str)
        # EN: fix for https://github.com/datamol-io/safe/issues/37
        # we were using the wrong branch number count which did not take into account
        # possible change in digit utilization after bond slicing
        scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers

        # don't capture atom mapping in the scaffold
        attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str))
        if canonical:
            attach_pos = sorted(attach_pos)
        # new ring-closure digits start right after the largest one already in use
        starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1
        for attach in attach_pos:
            val = str(starting_num) if starting_num < 10 else f"%{starting_num}"
            # we cannot have anything of the form "\([@=-#-$/\]*\d+\)"
            attach_regexp = re.compile(r"(" + re.escape(attach) + r")")
            scaffold_str = attach_regexp.sub(val, scaffold_str)
            starting_num += 1

        # now we need to remove all the parenthesis around digit only number
        wrong_attach = re.compile(r"\(([\%\d]*)\)")
        scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str)
        # furthermore, we autoapply rdkit-compatible digit standardization.
        if rdkit_safe:
            # NOTE(review): inside the character class, `=-@` is a range (matching = > ? @), so a
            # literal '-' bond symbol is not matched by this pattern. Confirm this is intended.
            pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)"
            replacement = r"\g<1>\g<2>"
            scaffold_str = re.sub(pattern, replacement, scaffold_str)
        if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp):
            logger.warning(
                "Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation"
            )
        return scaffold_str

__init__(slicer='brics', require_hs=None, use_original_opener_for_attach=True, ignore_stereo=False)

Constructor for the SAFE converter

Parameters:

Name Type Description Default
slicer Optional[Union[str, List[str], Callable]]

slicer algorithm to use for encoding. Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS) or a custom callable that returns the bond ids that can be sliced.

'brics'
require_hs Optional[bool]

whether the slicing algorithm requires the molecule to have hydrogens explicitly added. The attach slicer requires adding hydrogens.

None
use_original_opener_for_attach bool

whether to use the original branch opener digit when adding back mapping number to attachment points, or use simple enumeration.

True
ignore_stereo bool

RDKIT does not support some particular SAFE subset when stereochemistry is defined.

False
Source code in safe/converter.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def __init__(
    self,
    slicer: Optional[Union[str, List[str], Callable]] = "brics",
    require_hs: Optional[bool] = None,
    use_original_opener_for_attach: bool = True,
    ignore_stereo: bool = False,
):
    """Constructor for the SAFE converter

    Args:
        slicer: slicer algorithm to use for encoding.
            Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS, matched
            case-insensitively), a SMARTS string or list of SMARTS strings,
            or a custom callable that returns the bond ids that can be sliced.
        require_hs: whether the slicing algorithm require the molecule to have hydrogen explicitly added.
            `attach` slicer requires adding hydrogens.
        use_original_opener_for_attach: whether to use the original branch opener digit when adding back
            mapping number to attachment points, or use simple enumeration.
        ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined.

    Raises:
        ValueError: if one of the SMARTS patterns of the slicer cannot be parsed.
    """
    # Normalise the name once so that every later comparison is case-insensitive
    # (e.g. "BRICS" selects the BRICS slicer and "ATTACH" still turns on require_hs).
    slicer_key = slicer.lower() if isinstance(slicer, str) else None
    self.slicer = slicer
    if slicer_key in self.SUPPORTED_SLICERS:
        # "brics" has no SMARTS entry: it is kept as the sentinel string "brics"
        self.slicer = self.__SLICE_SMARTS.get(slicer_key, slicer_key)
    if self.slicer != "brics" and isinstance(self.slicer, str):
        # any other string is treated as a single user-provided SMARTS pattern
        self.slicer = [self.slicer]
    if isinstance(self.slicer, (list, tuple)):
        self.slicer = [dm.from_smarts(x) for x in self.slicer]
        if any(x is None for x in self.slicer):
            raise ValueError(f"Slicer: {slicer} cannot be valid")
    self.require_hs = require_hs or (slicer_key == "attach")
    self.use_original_opener_for_attach = use_original_opener_for_attach
    self.ignore_stereo = ignore_stereo

decoder(inp, as_mol=False, canonical=False, fix=True, remove_dummies=True, remove_added_hs=True)

Convert input SAFE representation to smiles

Parameters:

Name Type Description Default
inp str

input SAFE representation to decode as a valid molecule or smiles

required
as_mol bool

whether to return a molecule object or a smiles string

False
canonical bool

whether to return a canonical (standardized) molecule or SMILES

False
fix bool

whether to fix the SAFE representation to take into account non-connected attachment points

True
remove_dummies bool

whether to remove dummy atoms from the SAFE representation. Note that removing_dummies is incompatible with

True
remove_added_hs bool

whether to remove all the added hydrogen atoms after applying dummy removal for recovery

True
Source code in safe/converter.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def decoder(
    self,
    inp: str,
    as_mol: bool = False,
    canonical: bool = False,
    fix: bool = True,
    remove_dummies: bool = True,
    remove_added_hs: bool = True,
):
    """Decode a SAFE string into a smiles string or a molecule

    Args:
        inp: SAFE string to decode
        as_mol: return a molecule object instead of a smiles string
        canonical: return a canonical (standardized) molecule or smiles
        fix: complete the non-connected attachment points of the SAFE string before parsing it
        remove_dummies: remove the dummy atoms of the decoded molecule. This step is best-effort:
            any exception raised while removing them is suppressed.
        remove_added_hs: remove the added hydrogen atoms once the dummy atoms have been handled
    """
    safe_str = self._ensure_valid(inp) if fix else inp
    mol = dm.to_mol(safe_str)

    if remove_dummies:
        with suppress(Exception):
            # terminal dummies bound through a non-single, non-aromatic bond become carbons,
            # the remaining dummy atoms are then stripped
            dummy_query = dm.from_smarts("[$([#0]!-!:*);$([#0;D1])]")
            capped = Chem.ReplaceSubstructs(mol, dummy_query, dm.to_mol("C"), True)[0]
            mol = dm.remove_dummies(capped)

    if not as_mol:
        smiles = dm.to_smiles(mol, canonical=canonical, explicit_hs=(not remove_added_hs))
        return dm.standardize_smiles(smiles) if canonical else smiles

    if remove_added_hs:
        mol = dm.remove_hs(mol, update_explicit_count=True)
    if canonical:
        mol = dm.canonical_tautomer(dm.standardize_mol(mol))
    return mol

encoder(inp, canonical=True, randomize=False, seed=None, constraints=None, allow_empty=False, rdkit_safe=True)

Convert input smiles to SAFE representation

Parameters:

Name Type Description Default
inp Union[str, Mol]

input smiles

required
canonical bool

whether to return canonical smiles string. Defaults to True

True
randomize Optional[bool]

whether to randomize the SAFE string encoding. Fragment order is shuffled whenever this is set; the atom order of the input is additionally randomized only when canonical is False.

False
seed Optional[int]

optional seed to use when allowing randomization of the SAFE encoding. Randomization happens at two steps: 1. at the original SMILES representation, by randomizing the atom order. 2. at the SAFE conversion, by randomizing the fragment order.

None
constraints Optional[List[Mol]]

List of molecules or pattern to preserve during the SAFE construction. Any bond slicing would happen outside of a substructure matching one of the patterns.

None
allow_empty bool

whether to allow the slicing algorithm to return empty bonds

False
rdkit_safe bool

whether to apply rdkit-safe digit standardization to the output SAFE string.

True
Source code in safe/converter.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def encoder(
    self,
    inp: Union[str, dm.Mol],
    canonical: bool = True,
    randomize: Optional[bool] = False,
    seed: Optional[int] = None,
    constraints: Optional[List[dm.Mol]] = None,
    allow_empty: bool = False,
    rdkit_safe: bool = True,
):
    """Convert input smiles to SAFE representation

    Args:
        inp: input smiles
        canonical: whether to return canonical smiles string. Defaults to True
        randomize: whether to randomize the safe string encoding. The fragment order is shuffled
            whenever this is set; the atom order of the input is additionally randomized
            only when `canonical` is False.
        seed: optional seed to use when allowing randomization of the SAFE encoding.
            Randomization happens at two steps:
            1. at the original smiles representation by randomizing the atoms.
            2. at the SAFE conversion by randomizing fragment orders
        constraints: List of molecules or pattern to preserve during the SAFE construction. Any bond slicing would
            happen outside of a substructure matching one of the patterns.
        allow_empty: whether to allow the slicing algorithm to return empty bonds
        rdkit_safe: whether to apply rdkit-safe digit standardization to the output SAFE string.

    Returns:
        The SAFE string: dot-separated fragments joined back through ring-closure digits.
    """
    rng = None
    if randomize:
        rng = np.random.default_rng(seed)
        if not canonical:
            inp = dm.to_mol(inp, remove_hs=False)
            inp = self.randomize(inp, rng)

    if isinstance(inp, dm.Mol):
        inp = dm.to_smiles(inp, canonical=canonical, randomize=False, ordered=False)

    # EN: we first normalize the attachment if the molecule is a query:
    # inp = dm.reactions.convert_attach_to_isotope(inp, as_smiles=True)

    # TODO(maclandrol): RDKit supports some extended form of ring closure, up to 5 digits
    # https://www.rdkit.org/docs/RDKit_Book.html#ring-closures and I should try to include them
    branch_numbers = self._find_branch_number(inp)

    mol = dm.to_mol(inp, remove_hs=False)
    potential_stereos = Chem.FindPotentialStereo(mol)
    has_stereo_bonds = any(x.type == Chem.StereoType.Bond_Double for x in potential_stereos)
    if self.ignore_stereo:
        mol = dm.remove_stereochemistry(mol)

    # dummy atoms already present in the input get their atom map cleared and a unique
    # isotope label, so the labels given to the cut bonds below start after them
    bond_map_id = 1
    for atom in mol.GetAtoms():
        if atom.GetAtomicNum() == 0:
            atom.SetAtomMapNum(0)
            atom.SetIsotope(bond_map_id)
            bond_map_id += 1

    if self.require_hs:
        mol = dm.add_hs(mol)
    matching_bonds = self._fragment(mol, allow_empty=allow_empty)
    substructed_ignored = []
    if constraints is not None:
        substructed_ignored = list(
            itertools.chain(
                *[
                    mol.GetSubstructMatches(constraint, uniquify=True)
                    for constraint in constraints
                ]
            )
        )

    bonds = []
    for i_a, i_b in matching_bonds:
        # if both atoms of the bond are found in a disallowed substructure, we cannot consider them
        # on the other end, a bond between two substructure to preserved independently is perfectly fine
        if any((i_a in ignore_x and i_b in ignore_x) for ignore_x in substructed_ignored):
            continue
        obond = mol.GetBondBetweenAtoms(i_a, i_b)
        bonds.append(obond.GetIdx())

    if len(bonds) > 0:
        # both ends of a cut bond receive the same label so they can be paired again later
        mol = Chem.FragmentOnBonds(
            mol,
            bonds,
            dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))],
        )
    # here we need to be clever and disable rooted atom as the atom with mapping

    frags = list(Chem.GetMolFrags(mol, asMols=True))
    if randomize:
        frags = rng.permutation(frags).tolist()
    elif canonical:
        # largest fragment first
        frags = sorted(
            frags,
            key=lambda x: x.GetNumAtoms(),
            reverse=True,
        )

    frags_str = []
    for frag in frags:
        # root each fragment smiles at its first non-dummy atom
        non_map_atom_idxs = [
            atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0
        ]
        frags_str.append(
            Chem.MolToSmiles(
                frag,
                isomericSmiles=True,
                canonical=True,  # needs to always be true
                rootedAtAtom=non_map_atom_idxs[0],
            )
        )

    scaffold_str = ".".join(frags_str)
    # EN: fix for https://github.com/datamol-io/safe/issues/37
    # we were using the wrong branch number count which did not take into account
    # possible change in digit utilization after bond slicing
    scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers

    # don't capture atom mapping in the scaffold
    attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str))
    if canonical:
        attach_pos = sorted(attach_pos)
    # new ring-closure digits start right after the largest one already in use
    starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1
    for attach in attach_pos:
        val = str(starting_num) if starting_num < 10 else f"%{starting_num}"
        # we cannot have anything of the form "\([@=-#-$/\]*\d+\)"
        attach_regexp = re.compile(r"(" + re.escape(attach) + r")")
        scaffold_str = attach_regexp.sub(val, scaffold_str)
        starting_num += 1

    # now we need to remove all the parenthesis around digit only number
    wrong_attach = re.compile(r"\(([\%\d]*)\)")
    scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str)
    # furthermore, we autoapply rdkit-compatible digit standardization.
    if rdkit_safe:
        # NOTE(review): inside the character class, `=-@` is a range (matching = > ? @), so a
        # literal '-' bond symbol is not matched by this pattern. Confirm this is intended.
        pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)"
        replacement = r"\g<1>\g<2>"
        scaffold_str = re.sub(pattern, replacement, scaffold_str)
    if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp):
        logger.warning(
            "Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation"
        )
    return scaffold_str

randomize(mol, rng=None) staticmethod

Randomize the position of the atoms in a mol.

Parameters:

Name Type Description Default
mol Mol

molecules to randomize

required
rng Optional[int]

optional seed to use

None
Source code in safe/converter.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
@staticmethod
def randomize(mol: dm.Mol, rng: Optional[Union[int, np.random.Generator]] = None):
    """Randomize the position of the atoms in a mol.

    Args:
        mol: molecules to randomize
        rng: optional seed (int) or random generator to use. When None, a freshly
            seeded generator is created.

    Returns:
        The input molecule when it has no atoms, otherwise the molecule with its atoms renumbered
        according to a random permutation.
    """
    # `default_rng(None)` builds an unseeded generator, so the default argument no longer
    # fails on `rng.permutation`. Any other object is assumed to expose `.permutation`.
    if rng is None or isinstance(rng, int):
        rng = np.random.default_rng(rng)
    if mol.GetNumAtoms() == 0:
        return mol
    atom_indices = list(range(mol.GetNumAtoms()))
    atom_indices = rng.permutation(atom_indices).tolist()
    return Chem.RenumberAtoms(mol, atom_indices)

encode(inp, canonical=True, randomize=False, seed=None, slicer=None, require_hs=None, constraints=None, ignore_stereo=False)

Convert input smiles to SAFE representation

Parameters:

Name Type Description Default
inp Union[str, Mol]

input smiles

required
canonical bool

whether to return canonical SAFE string. Defaults to True

True
randomize Optional[bool]

whether to randomize the SAFE string encoding. Fragment order is shuffled whenever this is set; the atom order of the input is additionally randomized only when canonical is False.

False
seed Optional[int]

optional seed to use when allowing randomization of the SAFE encoding.

None
slicer Optional[Union[List[str], str, Callable]]

slicer algorithm to use for encoding. Defaults to "brics".

None
require_hs Optional[bool]

whether the slicing algorithm requires the molecule to have hydrogens explicitly added.

None
constraints Optional[List[Mol]]

List of molecules or pattern to preserve during the SAFE construction.

None
ignore_stereo Optional[bool]

RDKIT does not support some particular SAFE subset when stereochemistry is defined.

False
Source code in safe/converter.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
def encode(
    inp: Union[str, dm.Mol],
    canonical: bool = True,
    randomize: Optional[bool] = False,
    seed: Optional[int] = None,
    slicer: Optional[Union[List[str], str, Callable]] = None,
    require_hs: Optional[bool] = None,
    constraints: Optional[List[dm.Mol]] = None,
    ignore_stereo: Optional[bool] = False,
):
    """
    Encode an input smiles (or molecule) as a SAFE string

    A new `SAFEConverter` is created for each call, with RDKit logging disabled
    while the conversion runs.

    Args:
        inp: input smiles
        canonical: whether to return canonical SAFE string. Defaults to True
        randomize: whether to randomize the safe string encoding. Will be ignored if canonical is provided
        seed: optional seed to use when allowing randomization of the SAFE encoding.
        slicer: slicer algorithm to use for encoding. Defaults to "brics".
        require_hs: whether the slicing algorithm requires the molecule to have hydrogens explicitly added.
        constraints: List of molecules or pattern to preserve during the SAFE construction.
        ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined.

    Returns:
        The output of `SAFEConverter.encoder` for the input.

    Raises:
        SAFEFragmentationError: re-raised as-is when the encoder raises it.
        SAFEEncodeError: raised in place of any other exception coming from the encoder.
    """
    slicer = "brics" if slicer is None else slicer
    encoder_options = dict(
        canonical=canonical,
        randomize=randomize,
        constraints=constraints,
        seed=seed,
    )
    with dm.without_rdkit_log():
        converter = SAFEConverter(slicer=slicer, require_hs=require_hs, ignore_stereo=ignore_stereo)
        try:
            return converter.encoder(inp, **encoder_options)
        except SAFEFragmentationError:
            # fragmentation failures keep their own type for the caller
            raise
        except Exception as err:
            raise SAFEEncodeError(f"Failed to encode {inp} with {slicer}") from err

decode(safe_str, as_mol=False, canonical=False, fix=True, remove_added_hs=True, remove_dummies=True, ignore_errors=False)

Convert input SAFE representation to smiles

Args:

safe_str: input SAFE representation to decode as a valid molecule or smiles

as_mol: whether to return a molecule object or a smiles string

canonical: whether to return a canonical smiles or a randomized smiles

fix: whether to fix the SAFE representation to take into account non-connected attachment points

remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string.

remove_dummies: whether to remove dummy atoms from the SAFE representation

ignore_errors: whether to ignore errors and return None on decoding failure instead of raising an error

Source code in safe/converter.py
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
def decode(
    safe_str: str,
    as_mol: bool = False,
    canonical: bool = False,
    fix: bool = True,
    remove_added_hs: bool = True,
    remove_dummies: bool = True,
    ignore_errors: bool = False,
):
    """Decode a SAFE string back into a smiles (or a molecule object)

    A new `SAFEConverter` is created for each call, with RDKit logging disabled
    while the conversion runs.

    Args:
        safe_str: input SAFE representation to decode as a valid molecule or smiles
        as_mol: whether to return a molecule object or a smiles string
        canonical: whether to return a canonical smiles or a randomized smiles
        fix: whether to fix the SAFE representation to take into account non-connected attachment points
        remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string.
        remove_dummies: whether to remove dummy atoms from the SAFE representation
        ignore_errors: whether to ignore errors and return None on decoding failure instead of raising

    Returns:
        The output of `SAFEConverter.decoder`, or None when decoding fails and
        `ignore_errors` is set.

    Raises:
        SAFEDecodeError: when the decoder raises and `ignore_errors` is False.
    """
    decoder_options = dict(
        as_mol=as_mol,
        canonical=canonical,
        fix=fix,
        remove_dummies=remove_dummies,
        remove_added_hs=remove_added_hs,
    )
    with dm.without_rdkit_log():
        # the converter is built outside of the try block: only decoding errors are wrapped
        converter = SAFEConverter()
        try:
            return converter.decoder(safe_str, **decoder_options)
        except Exception as err:
            if not ignore_errors:
                raise SAFEDecodeError(f"Failed to decode {safe_str}") from err
            return None

SAFE Design

SAFEDesign

Molecular generation using SAFE pretrained model

Source code in safe/sample.py
  27
  28
  29
  30
  31
  32
  33
  34
  35
  36
  37
  38
  39
  40
  41
  42
  43
  44
  45
  46
  47
  48
  49
  50
  51
  52
  53
  54
  55
  56
  57
  58
  59
  60
  61
  62
  63
  64
  65
  66
  67
  68
  69
  70
  71
  72
  73
  74
  75
  76
  77
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
class SAFEDesign:
    """Molecular generation using SAFE pretrained model"""

    _DEFAULT_MAX_LENGTH = 1024  # default max length used during training
    _DEFAULT_MODEL_PATH = "datamol-io/safe-gpt"

    def __init__(
        self,
        model: Union[SAFEDoubleHeadsModel, str],
        tokenizer: Union[str, SAFETokenizer],
        generation_config: Optional[Union[str, GenerationConfig]] = None,
        safe_encoder: Optional[sf.SAFEConverter] = None,
        verbose: bool = True,
    ):
        """SAFEDesign constructor

        !!! info
            Design methods in SAFE are not deterministic when it comes to the token sampling step.
            If a method accepts a `random_seed`, it's for the SAFE-related algorithms and not the
            sampling from the autoregressive model. To ensure you get a deterministic sampling,
            please set the seed at the `transformers` package level.

            ```python
            import safe as sf
            import transformers
            my_seed = 100
            designer = sf.SAFEDesign(...)

            transformers.set_seed(my_seed) # use this before calling a design function
            designer.linker_generation(...)
            ```


        Args:
            model: input SAFEDoubleHeadsModel to use for generation. A string or path is loaded
                with `SAFEDoubleHeadsModel.from_pretrained`.
            tokenizer: input SAFETokenizer to use for generation. A string or path is loaded
                with `SAFETokenizer.load`.
            generation_config: input GenerationConfig to use for generation. A string or path is
                loaded with `GenerationConfig.from_pretrained`. When None, the configuration is
                derived from the model config.
            safe_encoder: custom safe encoder to use. Defaults to a new `sf.SAFEConverter()`.
            verbose: whether to print out logging information during generation
        """

        if isinstance(model, (str, os.PathLike)):
            model = SAFEDoubleHeadsModel.from_pretrained(model)

        if isinstance(tokenizer, (str, os.PathLike)):
            tokenizer = SAFETokenizer.load(tokenizer)

        model.eval()
        self.model = model
        self.tokenizer = tokenizer
        # Plain strings are accepted as well, consistently with `model` and `tokenizer`
        # and with the type annotation. A string used to be stored untouched and then
        # failed on the special token lookup below.
        if isinstance(generation_config, (str, os.PathLike)):
            generation_config = GenerationConfig.from_pretrained(generation_config)
        if generation_config is None:
            generation_config = GenerationConfig.from_model_config(model.config)
        self.generation_config = generation_config
        # special token ids missing from the generation config are taken from the tokenizer
        for special_token_id in ["bos_token_id", "eos_token_id", "pad_token_id"]:
            if getattr(self.generation_config, special_token_id) is None:
                setattr(
                    self.generation_config, special_token_id, getattr(tokenizer, special_token_id)
                )

        self.verbose = verbose
        self.safe_encoder = safe_encoder or sf.SAFEConverter()

    @classmethod
    def load_from_wandb(
        cls, artifact_path: str, device: Optional[str] = None, verbose: bool = True, **kwargs: Any
    ) -> "SAFEDesign":
        """
        Build a SAFEDesign from a model stored as a Weights and Biases (wandb) artifact.

        When the SAFE_MODEL_ROOT environment variable is set, the artifact is downloaded into a
        subfolder of that directory (created if needed). Otherwise no download root is given to wandb.

        Args:
            artifact_path: The path to the wandb artifact in the format `entity/project/artifact:version`.
                A `wandb://` scheme is stripped. Without any `/`, the project is read from the
                SAFE_WANDB_PROJECT environment variable (default `safe-models`). Without any `:`,
                the `latest` version is requested.
            device: The device where the model should be moved. The model is not moved when None.
            verbose: Whether to print out logging information during generation.
            kwargs: any additional argument to pass to the init function

        Returns:
            SAFEDesign: An instance of SAFEDesign class with the model, tokenizer, and generation config loaded from wandb.
        """
        # EN: potentially remove wandb scheme
        import wandb

        reference = artifact_path.replace("wandb://", "")

        # Everything before the first "/" is the project, the remainder is the artifact name
        project_name, slash, artifact_name = reference.partition("/")
        if not slash:
            project_name = os.getenv("SAFE_WANDB_PROJECT", "safe-models")
            artifact_name = reference

        if ":" not in artifact_name:
            artifact_name = artifact_name + ":latest"

        qualified_path = f"{project_name}/{artifact_name}"

        # Only use a dedicated download folder when SAFE_MODEL_ROOT is defined
        download_root = None
        model_root = os.getenv("SAFE_MODEL_ROOT", None)
        if model_root is not None:
            root_dir = Path(model_root)
            root_dir.mkdir(parents=True, exist_ok=True)
            subfolder = qualified_path.replace("/", "_").replace(":", "_")
            download_root = (root_dir / subfolder).as_posix()

        # Fetch the artifact and download its content
        artifact = wandb.Api().artifact(qualified_path, type="model")
        artifact_dir = artifact.download(root=download_root)

        # The model, tokenizer and generation config all live in the artifact directory
        model = SAFEDoubleHeadsModel.from_pretrained(artifact_dir)
        tokenizer = SAFETokenizer.from_pretrained(artifact_dir)
        gen_config = GenerationConfig.from_pretrained(artifact_dir)

        if device is not None:
            model = model.to(device)

        init_args = dict(
            model=model,
            tokenizer=tokenizer,
            generation_config=gen_config,
            verbose=verbose,
        )
        return cls(**init_args, **kwargs)

    @classmethod
    def load_default(
        cls,
        model_dir: Optional[str] = None,
        device: str = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "SAFEDesign":
        """Load the default SAFEGenerator model

        The model, the tokenizer and the generation config are all loaded from the same location.

        Args:
            verbose: whether to print out logging information during generation
            model_dir: Optional path to model folder to use instead of the default one
                (`_DEFAULT_MODEL_PATH`, also used when an empty value is given).
                If provided the tokenizer should be in the model_dir named as `tokenizer.json`
            device: optional device where to move the model
            kwargs: any additional argument to pass to the init function
        """
        # None and empty values both fall back to the default model location
        source = model_dir or cls._DEFAULT_MODEL_PATH

        model = SAFEDoubleHeadsModel.from_pretrained(source)
        tokenizer = SAFETokenizer.from_pretrained(source)
        gen_config = GenerationConfig.from_pretrained(source)
        if device is not None:
            model = model.to(device)

        init_args = dict(
            model=model,
            tokenizer=tokenizer,
            generation_config=gen_config,
            verbose=verbose,
        )
        return cls(**init_args, **kwargs)

    def linker_generation(
        self,
        *groups: Union[str, dm.Mol],
        n_samples_per_trial: int = 10,
        n_trials: Optional[int] = 1,
        sanitize: bool = False,
        do_not_fragment_further: Optional[bool] = True,
        random_seed: Optional[int] = None,
        model_only: Optional[bool] = False,
        **kwargs: Optional[Dict[Any, Any]],
    ):
        """Generate linkers between two fragments using the pretrained SAFE model.

        This delegates to the same routine as scaffold morphing, with the linking mode enabled.

        Args:
            groups: the two fragments to link together, they are joined in the order provided
            n_samples_per_trial: number of new molecules to generate for each randomization
            n_trials: number of randomization to perform
            do_not_fragment_further: whether to fragment the scaffold further or not
            sanitize: whether to sanitize the generated molecules
            random_seed: random seed to use
            model_only: whether to use the model only ability and nothing more.
            kwargs: any argument to provide to the underlying generation function

        Raises:
            ValueError: when the number of groups provided is not exactly two
        """
        if len(groups) != 2:
            raise ValueError(
                "Linker generation only works when providing two groups as side chains"
            )

        linking_options = dict(
            side_chains=list(groups),
            n_samples_per_trial=n_samples_per_trial,
            n_trials=n_trials,
            sanitize=sanitize,
            do_not_fragment_further=do_not_fragment_further,
            random_seed=random_seed,
            is_linking=True,
            model_only=model_only,
        )
        return self._fragment_linking(**linking_options, **kwargs)

    def scaffold_morphing(
        self,
        side_chains: Optional[Union[dm.Mol, str, List[Union[str, dm.Mol]]]] = None,
        mol: Optional[Union[dm.Mol, str]] = None,
        core: Optional[Union[dm.Mol, str]] = None,
        n_samples_per_trial: int = 10,
        n_trials: Optional[int] = 1,
        sanitize: bool = False,
        do_not_fragment_further: Optional[bool] = True,
        random_seed: Optional[int] = None,
        **kwargs: Optional[Dict[Any, Any]],
    ):
        """Replace the core of a molecule by a new scaffold using the pretrained SAFE model

        The side chains to keep are either given directly, or extracted from the combination of
        a molecule and its core. The work is delegated to the fragment linking routine with the
        linking mode disabled.

        !!! note "Finding the side chains"
            The algorithm to find the side chains from core assumes that the core we get as input has attachment points.
            Those attachment points are never considered as part of the query, rather they are used to define the attachment points.
            See ~sf.utils.compute_side_chains for more information.

        Args:
            side_chains: side chains to use to perform scaffold morphing (joining as best as possible the set of fragments)
            mol: input molecules when side_chains are not provided
            core: core to morph into another scaffold
            n_samples_per_trial: number of new molecules to generate for each randomization
            n_trials: number of randomization to perform
            do_not_fragment_further: whether to fragment the scaffold further or not
            sanitize: whether to sanitize the generated molecules
            random_seed: random seed to use
            kwargs: any argument to provide to the underlying generation function
        """
        morphing_options = dict(
            side_chains=side_chains,
            mol=mol,
            core=core,
            n_samples_per_trial=n_samples_per_trial,
            n_trials=n_trials,
            sanitize=sanitize,
            do_not_fragment_further=do_not_fragment_further,
            random_seed=random_seed,
            is_linking=False,
        )
        return self._fragment_linking(**morphing_options, **kwargs)

    def _fragment_linking(
        self,
        side_chains: Optional[Union[dm.Mol, str, List[Union[str, dm.Mol]]]] = None,
        mol: Optional[Union[dm.Mol, str]] = None,
        core: Optional[Union[dm.Mol, str]] = None,
        n_samples_per_trial: int = 10,
        n_trials: Optional[int] = 1,
        sanitize: bool = False,
        do_not_fragment_further: Optional[bool] = False,
        random_seed: Optional[int] = None,
        is_linking: Optional[bool] = False,
        model_only: Optional[bool] = False,
        **kwargs: Optional[Dict[Any, Any]],
    ):
        """Shared implementation of scaffold morphing and linker generation

        For scaffold morphing, we try to replace the core by a new one. If the side_chains are provided, we use them.
        If a combination of molecule and core is provided, then, we use them to extract the side chains and performing the
        scaffold morphing then.

        For each trial, the side chains are encoded as a SAFE string which is split in a prefix
        and a suffix. Both are then completed by the model and the results are combined.

        !!! note "Finding the side chains"
            The algorithm to find the side chains from core assumes that the core we get as input has attachment points.
            Those attachment points are never considered as part of the query, rather they are used to define the attachment points.
            See ~sf.utils.compute_side_chains for more information.

        Args:
            side_chains: side chains to use to perform scaffold morphing (joining as best as possible the set of fragments)
            mol: input molecules when side_chains are not provided
            core: core to morph into another scaffold
            n_samples_per_trial: number of new molecules to generate for each randomization
            n_trials: number of randomization to perform
            do_not_fragment_further: whether to fragment the scaffold further or not
            sanitize: whether to sanitize the generated molecules
            random_seed: random seed to use
            is_linking: whether it's a linking task or not.
                For linking tasks, we use a different custom strategy of completing up to the attachment signal
            model_only: whether to use the model only ability and nothing more. Only relevant when doing linker generation
            kwargs: any argument to provide to the underlying generation function

        Returns:
            The list of sequences generated over all the trials, filtered by the side chains
            substructure constraints when `sanitize` is set.

        Raises:
            ValueError: when neither side_chains nor both mol and core are provided, or when
                the encoded side chains cannot be split between their open attachment points.
            SAFEEncodeError: when the side chains cannot be encoded as a SAFE string.
        """
        if side_chains is None:
            # both the molecule and its core are required to extract the side chains
            if mol is None or core is None:
                raise ValueError(
                    "Either side_chains OR mol+core should be provided for scaffold morphing"
                )
            # NOTE(review): the docstring points to sf.utils.compute_side_chains while this call
            # goes through sf.trainer.utils, confirm that both resolve to the same helper.
            side_chains = sf.trainer.utils.compute_side_chains(mol, core)
        side_chains = (
            [dm.to_mol(x) for x in side_chains]
            if isinstance(side_chains, list)
            else [dm.to_mol(side_chains)]
        )

        side_chains = ".".join([dm.to_smiles(x) for x in side_chains])

        if "*" not in side_chains and self.verbose:
            logger.warning(
                f"Side chain {side_chains} does not contain any dummy atoms, this might not be what you want"
            )

        rng = random.Random(random_seed)
        new_seed = rng.randint(1, 1000)

        total_sequences = []
        n_trials = n_trials or 1
        for _ in tqdm(range(n_trials), disable=(not self.verbose), leave=False):
            with dm.without_rdkit_log():
                # temporarily remove the slicer of the encoder when no further fragmentation is wanted
                context_mng = (
                    sf.utils.attr_as(self.safe_encoder, "slicer", None)
                    if do_not_fragment_further
                    else suppress()
                )
                old_slicer = getattr(self.safe_encoder, "slicer", None)
                with context_mng:
                    try:
                        encoded_fragment = self.safe_encoder.encoder(
                            side_chains,
                            canonical=False,
                            randomize=False,
                            constraints=None,
                            allow_empty=True,
                            seed=new_seed,
                        )

                    except Exception as e:
                        if self.verbose:
                            logger.error(e)
                        raise sf.SAFEEncodeError(f"Failed to encode {side_chains}") from e
                    finally:
                        if old_slicer is not None:
                            self.safe_encoder.slicer = old_slicer

            fragments = encoded_fragment.split(".")
            # closure labels found an odd number of times are the ones left open
            missing_closure = Counter(self.safe_encoder._find_branch_number(encoded_fragment))
            missing_closure = [f"{str(x)}" for x in missing_closure if missing_closure[x] % 2 == 1]

            closure_pos = [
                m.start() for x in missing_closure for m in re.finditer(x, encoded_fragment)
            ]
            fragment_pos = [m.start() for m in re.finditer(r"\.", encoded_fragment)]
            if not closure_pos:
                # without any open closure there is nothing to anchor the split on
                raise ValueError(
                    f"Side chain {side_chains} does not have any open attachment point to link"
                )

            # The bound checks come first so that the lists are never indexed out of range.
            # skip the fragment separators located before the first open closure
            min_pos = 0
            while min_pos < len(fragment_pos) and fragment_pos[min_pos] < closure_pos[0]:
                min_pos += 1
            min_pos += 1
            # drop the fragment separators located after the last open closure
            max_pos = len(fragment_pos)
            while max_pos > 0 and fragment_pos[max_pos - 1] > closure_pos[-1]:
                max_pos -= 1

            if min_pos > max_pos:
                raise ValueError(
                    f"Cannot split {encoded_fragment} into a prefix and a suffix between its open attachment points"
                )

            split_index = rng.randint(min_pos, max_pos)
            prefix, suffixes = ".".join(fragments[:split_index]), ".".join(fragments[split_index:])

            missing_prefix_closure = Counter(self.safe_encoder._find_branch_number(prefix))
            missing_suffix_closure = Counter(self.safe_encoder._find_branch_number(suffixes))

            missing_prefix_closure = (
                ["."] + [x for x in missing_closure if int(x) not in missing_prefix_closure] + ["."]
            )
            missing_suffix_closure = (
                ["."] + [x for x in missing_closure if int(x) not in missing_suffix_closure] + ["."]
            )

            # every ordering of the open closures followed by a fragment separator
            constraints_ids = []
            for permutation in itertools.permutations(missing_closure + ["."]):
                constraints_ids.append(
                    self.tokenizer.encode(list(permutation), add_special_tokens=False)
                )

            # prefix_constraints_ids = self.tokenizer.encode(missing_prefix_closure, add_special_tokens=False)
            # suffix_constraints_ids = self.tokenizer.encode(missing_suffix_closure, add_special_tokens=False)

            # suffix_ids = self.tokenizer.encode([suffixes+self.tokenizer.tokenizer.eos_token], add_special_tokens=False)
            # prefix_ids = self.tokenizer.encode([prefix], add_special_tokens=False)

            prefix_kwargs = kwargs.copy()
            suffix_kwargs = prefix_kwargs.copy()

            if is_linking and model_only:
                # user-provided generation options take precedence over these defaults
                for _kwargs in [prefix_kwargs, suffix_kwargs]:
                    _kwargs.setdefault("how", "beam")
                    _kwargs.setdefault("num_beams", n_samples_per_trial)
                    _kwargs.setdefault("do_sample", False)

                prefix_kwargs["constraints"] = []
                suffix_kwargs["constraints"] = []
                # prefix_kwargs["constraints"] = [PhrasalConstraint(tkl) for tkl in suffix_constraints_ids]
                # suffix_kwargs["constraints"] = [PhrasalConstraint(tkl) for tkl in prefix_constraints_ids]

                # we first generate a part of the fragment with for unique constraint that it should contain
                # the closure required to join something to the suffix.
                prefix_kwargs["constraints"] += [
                    DisjunctiveConstraint(tkl) for tkl in constraints_ids
                ]
                suffix_kwargs["constraints"] += [
                    DisjunctiveConstraint(tkl) for tkl in constraints_ids
                ]

                prefix_sequences = self._generate(
                    n_samples=n_samples_per_trial, safe_prefix=prefix, **prefix_kwargs
                )
                suffix_sequences = self._generate(
                    n_samples=n_samples_per_trial, safe_prefix=suffixes, **suffix_kwargs
                )

                prefix_sequences = [
                    self._find_fragment_cut(x, prefix, missing_prefix_closure[1])
                    for x in prefix_sequences
                ]
                suffix_sequences = [
                    self._find_fragment_cut(x, suffixes, missing_suffix_closure[1])
                    for x in suffix_sequences
                ]

                linkers = [x for x in set(prefix_sequences + suffix_sequences) if x]
                sequences = [f"{prefix}.{linker}.{suffixes}" for linker in linkers]
                # NOTE(review): the decoded sequences are appended to the assembled SAFE strings
                # rather than replacing them, confirm that returning both is intended.
                sequences += self._decode_safe(sequences, canonical=True, remove_invalid=sanitize)

            else:
                mol_linker_slicer = sf.utils.MolSlicer(
                    shortest_linker=(not is_linking), require_ring_system=(not is_linking)
                )
                prefix_smiles = sf.decode(prefix, remove_dummies=False, as_mol=False)
                suffix_smiles = sf.decode(suffixes, remove_dummies=False, as_mol=False)

                # the trailing "." asks the model to continue with a new fragment
                prefix_sequences = self._generate(
                    n_samples=n_samples_per_trial, safe_prefix=prefix + ".", **prefix_kwargs
                )
                suffix_sequences = self._generate(
                    n_samples=n_samples_per_trial, safe_prefix=suffixes + ".", **suffix_kwargs
                )

                prefix_sequences = self._decode_safe(
                    prefix_sequences, canonical=True, remove_invalid=True
                )
                suffix_sequences = self._decode_safe(
                    suffix_sequences, canonical=True, remove_invalid=True
                )
                sequences = self.__mix_sequences(
                    prefix_sequences,
                    suffix_sequences,
                    prefix_smiles,
                    suffix_smiles,
                    n_samples_per_trial,
                    mol_linker_slicer,
                )

            total_sequences.extend(sequences)

        # then we should filter out molecules that do not match the requested
        if sanitize:
            total_sequences = sf.utils.filter_by_substructure_constraints(
                total_sequences, side_chains
            )
            if self.verbose:
                logger.info(
                    f"After sanitization, {len(total_sequences)} / {n_samples_per_trial*n_trials} ({len(total_sequences)*100/(n_samples_per_trial*n_trials):.2f} %)  generated molecules are valid !"
                )
        return total_sequences

    def motif_extension(
        self,
        motif: Union[str, dm.Mol],
        n_samples_per_trial: int = 10,
        n_trials: Optional[int] = 1,
        sanitize: bool = False,
        do_not_fragment_further: Optional[bool] = True,
        random_seed: Optional[int] = None,
        **kwargs: Optional[Dict[Any, Any]],
    ):
        """Extend a motif using the pretrained SAFE model.

        This is a thin wrapper around `scaffold_decoration`, called with `add_dot=True`.

        Args:
            motif: scaffold (with attachment points) to decorate
            n_samples_per_trial: number of new molecules to generate for each randomization
            n_trials: number of randomization to perform
            do_not_fragment_further: whether to fragment the scaffold further or not
            sanitize: whether to sanitize the generated molecules and check
            random_seed: random seed to use
            kwargs: any argument to provide to the underlying generation function
        """
        decoration_options = dict(
            n_samples_per_trial=n_samples_per_trial,
            n_trials=n_trials,
            sanitize=sanitize,
            do_not_fragment_further=do_not_fragment_further,
            random_seed=random_seed,
            add_dot=True,
        )
        return self.scaffold_decoration(motif, **decoration_options, **kwargs)

    def super_structure(
        self,
        core: Union[str, dm.Mol],
        n_samples_per_trial: int = 10,
        n_trials: Optional[int] = 1,
        sanitize: bool = False,
        do_not_fragment_further: Optional[bool] = True,
        random_seed: Optional[int] = None,
        attachment_point_depth: Optional[int] = None,
        **kwargs: Optional[Dict[Any, Any]],
    ):
        """Perform super structure generation using the pretrained SAFE model.

        To generate super-structures, various attachment points are opened on the input core,
        then each opened core is completed by the model (one opened core per trial, cycling
        through the candidates when `n_trials` exceeds their number).

        Args:
            core: input substructure to use. We aim to generate super structures of this molecule
            n_samples_per_trial: number of new molecules to generate for each randomization
            n_trials: number of different attachment points to consider (None is treated as 1)
            sanitize: whether to sanitize the generated molecules
            do_not_fragment_further: whether to fragment the scaffold further or not
            random_seed: random seed to use; it seeds the shuffling of the opened cores
                and is forwarded to the completion step
            attachment_point_depth: depth of opening the attachment points.
                Increasing this, means you increase the number of substitution point to consider.
            kwargs: any argument to provide to the underlying generation function

        Returns:
            List with the sequences produced by all trials that did not raise.
        """

        core = dm.to_mol(core)
        cores = sf.utils.list_individual_attach_points(core, depth=attachment_point_depth)
        # always include the fully opened molecule as a candidate too
        cores.append(dm.to_smiles(dm.reactions.open_attach_points(core)))
        # Deduplicate, then sort before shuffling: iteration order of a set of strings
        # changes between interpreter runs (hash randomization), which would otherwise
        # make the seeded shuffle below irreproducible for a fixed `random_seed`.
        cores = sorted(set(cores), key=str)
        rng = random.Random(random_seed)
        rng.shuffle(cores)

        total_sequences = []
        n_trials = n_trials or 1
        for trial_idx in tqdm(range(n_trials), disable=(not self.verbose), leave=False):
            # cycle through the candidate cores when there are more trials than cores
            core = cores[trial_idx % len(cores)]
            old_verbose = self.verbose
            try:
                # silence the inner completion; logging is handled at this level
                with sf.utils.attr_as(self, "verbose", False):
                    out = self._completion(
                        fragment=core,
                        n_samples_per_trial=n_samples_per_trial,
                        n_trials=1,
                        do_not_fragment_further=do_not_fragment_further,
                        sanitize=sanitize,
                        random_seed=random_seed,
                        **kwargs,
                    )
                    total_sequences.extend(out)
            except Exception as e:
                # best effort: a failing core is reported (when verbose) and skipped
                if old_verbose:
                    logger.error(e)

            finally:
                self.verbose = old_verbose

        if sanitize and self.verbose:
            logger.info(
                f"After sanitization, {len(total_sequences)} / {n_samples_per_trial*n_trials} ({len(total_sequences)*100/(n_samples_per_trial*n_trials):.2f} %)  generated molecules are valid !"
            )
        return total_sequences

    def scaffold_decoration(
        self,
        scaffold: Union[str, dm.Mol],
        n_samples_per_trial: int = 10,
        n_trials: Optional[int] = 1,
        do_not_fragment_further: Optional[bool] = True,
        sanitize: bool = False,
        random_seed: Optional[int] = None,
        add_dot: Optional[bool] = True,
        **kwargs: Optional[Dict[Any, Any]],
    ):
        """Perform scaffold decoration using the pretrained SAFE model

        For scaffold decoration, we basically start with a prefix with the attachment point.
        We first convert the prefix into a valid safe string.

        Args:
            scaffold: scaffold (with attachment points) to decorate
            n_samples_per_trial: number of new molecules to generate for each randomization
            n_trials: number of randomization to perform (None is treated as 1)
            do_not_fragment_further: whether to fragment the scaffold further or not
            sanitize: whether to sanitize the generated molecules and check if the scaffold is still present
            random_seed: random seed to use
            add_dot: forwarded to the completion step; whether to add a dot at the end of the
                encoded scaffold so that generation starts a distinct fragment
            kwargs: any argument to provide to the underlying generation function

        Returns:
            List of generated sequences; when `sanitize` is set, only those that pass
            the substructure filter against `scaffold`.
        """

        # `n_trials` is Optional: normalise it here as well, otherwise the
        # statistics logged below fail with `int * None`.
        n_trials = n_trials or 1

        total_sequences = self._completion(
            fragment=scaffold,
            n_samples_per_trial=n_samples_per_trial,
            n_trials=n_trials,
            do_not_fragment_further=do_not_fragment_further,
            sanitize=sanitize,
            random_seed=random_seed,
            add_dot=add_dot,
            **kwargs,
        )
        # if we require sanitization
        # then we should filter out molecules that do not match the requested
        if sanitize:
            total_sequences = sf.utils.filter_by_substructure_constraints(total_sequences, scaffold)
            if self.verbose:
                logger.info(
                    f"After sanitization, {len(total_sequences)} / {n_samples_per_trial*n_trials} ({len(total_sequences)*100/(n_samples_per_trial*n_trials):.2f} %)  generated molecules are valid !"
                )
        return total_sequences

    def pattern_decoration(
        self,
        scaffold: Union[str, dm.Mol],
        n_samples_per_trial: int = 10,
        n_trials: int = 1,
        do_not_fragment_further: bool = True,
        sanitize: bool = False,
        random_seed: Optional[int] = None,
        add_dot: bool = True,
        n_scaff_random: Optional[int] = 3,
        n_scaff_samples: Optional[int] = 10,
        scaff_temperature: float = 1.0,
        **kwargs: Optional[Dict[Any, Any]],
    ) -> List[str]:
        """
        Perform pattern decoration using the pretrained SAFE model. The pattern decoration algorithm works by first exemplifying the patterns
        as a set of scaffolds, then performing scaffold decoration on each scaffold.

        !!! warning
            Designing molecules from a given molecule pattern is more challenging than fragment-constrained design.
            SAFE does not currently support complex SMARTS pattern schemes (e.g., valence or connectivity constraints, some ring constraints).
            This function works best when sampling given a list of atoms. However, sampling depends on the model's conditional probabilities,
            meaning that if the model assigns zero probability to a token, you are unlikely to see it.

        Args:
            scaffold: Scaffold (with attachment points) to decorate.
            n_samples_per_trial: Number of new molecules to generate for each randomization.
            n_trials: Number of randomizations to perform.
            do_not_fragment_further: Whether to prevent further fragmentation of the scaffold.
            sanitize: Whether to sanitize the generated molecules and ensure the scaffold is present.
            random_seed: Seed for randomization. When provided it also seeds the final shuffle of the results.
            add_dot: Forwarded to the completion step; whether to add a dot at the end of each scaffold prefix.
            n_scaff_random: Number of scaffold randomizations to try (to reposition constraints in the string and increase rollout likelihood).
                Increasing this will improve sampling, but will require more time.
            n_scaff_samples: Maximum number of samples to sample for a given scaffold from the pattern.
                Increasing this will make sure you have more diversity in the scaffold coming from the pattern.
                `None` means no extra cap beyond `n_samples_per_trial`.
            scaff_temperature: Temperature to use when sampling valid scaffolds from the pattern. Higher temperature means more diverse scaffold
            kwargs: Additional arguments for the underlying generation function.

        Returns:
            List of decorated molecule sequences (at most `n_samples_per_trial` of them).
        """

        smarts_scaffolds = [scaffold]
        if n_scaff_random and n_scaff_random > 0:
            smarts_scaffolds = PatternConstraint.randomize(scaffold, n_scaff_random)

        # `n_scaff_samples` is Optional: `min(int, None)` would raise a TypeError.
        scaffold_budget = (
            n_samples_per_trial
            if n_scaff_samples is None
            else min(n_samples_per_trial, n_scaff_samples)
        )

        all_scaffolds = set()
        for sm in smarts_scaffolds:
            cur_dec_pattern = PatternConstraint(sm, self.tokenizer, temperature=scaff_temperature)
            decorator = PatternSampler(self.model, cur_dec_pattern)
            cur_scaffolds = decorator.sample_scaffolds(
                n_samples=scaffold_budget,
                n_trials=1,
                random_seed=random_seed,
            )
            all_scaffolds.update(cur_scaffolds)

        # Fix the iteration order: set order of strings varies between interpreter
        # runs (hash randomization), which would defeat the seeded shuffle below.
        all_scaffolds = sorted(all_scaffolds, key=str)

        if sanitize:
            all_scaffolds = [x for x in all_scaffolds if dm.from_smarts(x) is not None]
        total_sequences = []
        # spread the requested number of samples over the sampled scaffolds
        per_scaffold_samples = int(n_samples_per_trial / max(len(all_scaffolds), 1)) + 1
        for scaff in all_scaffolds:
            # best effort: a scaffold that fails to complete is silently skipped
            with suppress(Exception), dm.without_rdkit_log():
                cur_sequences = self._completion(
                    fragment=dm.from_smarts(scaff),
                    n_samples_per_trial=per_scaffold_samples,
                    n_trials=n_trials,
                    do_not_fragment_further=do_not_fragment_further,
                    sanitize=sanitize,
                    random_seed=random_seed,
                    add_dot=add_dot,
                    **kwargs,
                )
                total_sequences.extend(cur_sequences)

        # Honor `random_seed` for the shuffle; without a seed keep using the
        # module-level generator so a caller's global `random.seed(...)` still applies.
        shuffler = random.Random(random_seed) if random_seed is not None else random
        shuffler.shuffle(total_sequences)
        if sanitize:
            total_sequences = sf.utils.filter_by_substructure_constraints(total_sequences, scaffold)
            total_sequences = total_sequences[:n_samples_per_trial]
            if self.verbose:
                logger.info(
                    f"After sanitization, {len(total_sequences)} / {n_samples_per_trial * n_trials} "
                    f"({len(total_sequences) * 100 / (n_samples_per_trial * n_trials):.2f}%) generated molecules are valid!"
                )

        return total_sequences[:n_samples_per_trial]

    def de_novo_generation(
        self,
        n_samples_per_trial: int = 10,
        sanitize: bool = False,
        n_trials: Optional[int] = None,
        **kwargs: Optional[Dict[Any, Any]],
    ):
        """Generate molecules from scratch with the pretrained SAFE model.

        No prefix is given to the model: each trial calls the generation routine
        directly and all raw outputs are decoded together at the end.

        Args:
            n_samples_per_trial: number of new molecules to generate
            sanitize: whether to perform sanitization, aka, perform control to ensure what is asked is what is returned
            n_trials: number of randomization to perform (None or 0 is treated as 1)
            kwargs: any argument to provide to the underlying generation function

        Returns:
            The decoded sequences; invalid entries are dropped only when `sanitize` is set.
        """
        # Default to random sampling and warn when neither `how="random"` nor `do_sample` is set.
        sampling_mode = kwargs.setdefault("how", "random")
        if not (sampling_mode == "random" or kwargs.get("do_sample")):
            logger.warning(
                "I don't think you know what you are doing ... for de novo generation `do_sample=True` or `how='random'` is expected !"
            )

        n_trials = n_trials or 1
        raw_sequences = []
        for _ in tqdm(range(n_trials), disable=(not self.verbose), leave=False):
            raw_sequences += self._generate(n_samples=n_samples_per_trial, **kwargs)

        total_sequences = self._decode_safe(
            raw_sequences, canonical=True, remove_invalid=sanitize
        )

        if self.verbose and sanitize:
            logger.info(
                f"After sanitization, {len(total_sequences)} / {n_samples_per_trial*n_trials} ({len(total_sequences)*100/(n_samples_per_trial*n_trials):.2f} %) generated molecules are valid !"
            )
        return total_sequences

    def _find_fragment_cut(self, fragment: str, prefix_constraint: str, branching_id: str):
        """
        Perform a cut on the input fragment in such a way that it could be joined with another fragments sharing the same
        branching id.

        Args:
            fragment: fragment to cut
            prefix_constraint: prefix constraint to use
            branching_id: branching id to use
        """
        prefix_constraint = prefix_constraint.rstrip(".") + "."
        fragment = (
            fragment.replace(prefix_constraint, "", 1)
            if fragment.startswith(prefix_constraint)
            else fragment
        )
        fragments = fragment.split(".")
        i = 0
        for x in fragments:
            if branching_id in x:
                i += 1
                break
        return ".".join(fragments[:i])

    def __mix_sequences(
        self,
        prefix_sequences: List[str],
        suffix_sequences: List[str],
        prefix: str,
        suffix: str,
        n_samples: int,
        mol_linker_slicer,
    ):
        """Use generated prefix and suffix sequences to form new molecules
        that will be the merging of both. This is the two step scaffold morphing and linker generation scheme

        Args:
            prefix_sequences: list of prefix sequences
            suffix_sequences: list of suffix sequences
            prefix: decoded smiles of the prefix
            suffix: decoded smiles of the suffix
            n_samples: number of samples to generate
            mol_linker_slicer: callable invoked as `mol_linker_slicer(mol, query)`, whose
                result's second item is used as the linker; it must also expose
                `link_fragments(linker, prefix, suffix)`

        Returns:
            At most `n_samples` non-empty linked results.
        """
        prefix_query = dm.from_smarts(prefix)
        suffix_query = dm.from_smarts(suffix)

        # Extract one linker per generated sequence, prefix-derived ones first.
        # Sequences that fail to parse or slice are skipped (best effort).
        linkers = []
        for sequences, query in (
            (prefix_sequences, prefix_query),
            (suffix_sequences, suffix_query),
        ):
            for seq in sequences:
                with suppress(Exception):
                    sliced = mol_linker_slicer(dm.to_mol(seq), query)
                    if sliced[1] is not None:
                        linkers.append(sliced[1])

        linked = []
        for linker in linkers:
            # Stop on the number of molecules actually produced, not on the number of
            # linkers visited.
            if len(linked) >= n_samples:
                break
            candidates = mol_linker_slicer.link_fragments(linker, prefix, suffix)
            # Filter before accumulating so that no falsy entry can reach the result
            # (the filter used to be skipped on the iteration that hit `break`).
            linked.extend(x for x in candidates if x)
        return linked[:n_samples]

    def _decode_safe(
        self, sequences: List[str], canonical: bool = True, remove_invalid: bool = False
    ):
        """Decode a batch of safe sequences with `sf.decode`

        Args:
            sequences: safe sequences to decode
            canonical: whether to return canonical sequence
            remove_invalid: whether to remove invalid safe strings or keep them
                (entries decoded to None are the ones removed)

        Returns:
            List with one decoded entry per input sequence, minus the None
            entries when `remove_invalid` is set.
        """

        def _decode_one(safe_str):
            # errors are ignored by the decoder, so a failure is not raised here
            decode_options = dict(
                as_mol=False,
                fix=True,
                remove_added_hs=True,
                canonical=canonical,
                ignore_errors=True,
                remove_dummies=True,
            )
            return sf.decode(safe_str, **decode_options)

        # larger batches (more than 100 sequences) go through the parallel helper
        use_parallel = len(sequences) > 100
        decoded = (
            dm.parallelized(_decode_one, sequences, n_jobs=-1)
            if use_parallel
            else [_decode_one(safe_str) for safe_str in sequences]
        )
        if not remove_invalid:
            return decoded
        return [item for item in decoded if item is not None]

    def _completion(
        self,
        fragment: Union[str, dm.Mol],
        n_samples_per_trial: int = 10,
        n_trials: Optional[int] = 1,
        do_not_fragment_further: Optional[bool] = False,
        sanitize: bool = False,
        random_seed: Optional[int] = None,
        add_dot: Optional[bool] = False,
        is_safe: Optional[bool] = False,
        **kwargs,
    ):
        """Perform sentence completion using a prefix fragment

        For each trial the fragment is encoded into a safe prefix (unless `is_safe`),
        the model generates `n_samples_per_trial` continuations of that prefix, and the
        outputs are decoded.

        Args:
            fragment: fragment (with attachment points)
            n_samples_per_trial: number of new molecules to generate for each randomization
            n_trials: number of randomization to perform (None is treated as 1)
            do_not_fragment_further: whether to fragment the scaffold further or not
            sanitize: whether to sanitize the generated molecules
            random_seed: random seed to use; it seeds the generator from which the
                per-trial encoder seeds are drawn
            is_safe: whether the smiles is already encoded as a safe string
            add_dot: whether to add a dot at the end of the fragments to signal to the model that we want to generate a distinct fragment.
            kwargs: any argument to provide to the underlying generation function

        Returns:
            List of decoded sequences accumulated over all trials.

        Raises:
            sf.SAFEEncodeError: if the fragment cannot be encoded.
        """

        # EN: lazy programming much ?
        kwargs.setdefault("how", "random")
        if kwargs["how"] != "random" and not kwargs.get("do_sample"):
            logger.warning(
                "I don't think you know what you are doing ... for de novo generation `do_sample=True` or `how='random'` is expected !"
            )

        # Step 1: we convert the fragment into the relevant safe string format
        # we use the provided safe encoder with the slicer that was expected

        rng = random.Random(random_seed)

        total_sequences = []
        n_trials = n_trials or 1
        for _ in tqdm(range(n_trials), disable=(not self.verbose), leave=False):
            if is_safe:
                encoded_fragment = fragment
            else:
                # Draw a fresh encoder seed for every trial. The seed used to be drawn
                # once before the loop, so every trial passed the same seed to the
                # encoder. The first draw is unchanged, so `n_trials=1` behaves as before.
                new_seed = rng.randint(1, 1000)
                with dm.without_rdkit_log():
                    # temporarily disable the encoder's slicer when no further
                    # fragmentation is wanted; otherwise use a no-op context
                    context_mng = (
                        sf.utils.attr_as(self.safe_encoder, "slicer", None)
                        if do_not_fragment_further
                        else suppress()
                    )
                    old_slicer = getattr(self.safe_encoder, "slicer", None)
                    with context_mng:
                        try:
                            encoded_fragment = self.safe_encoder.encoder(
                                fragment,
                                canonical=False,
                                randomize=True,
                                constraints=None,
                                allow_empty=True,
                                seed=new_seed,
                            )

                        except Exception as e:
                            if self.verbose:
                                logger.error(e)
                            raise sf.SAFEEncodeError(f"Failed to encode {fragment}") from e
                        finally:
                            # put the original slicer back, success or failure
                            if old_slicer is not None:
                                self.safe_encoder.slicer = old_slicer

            # only append the separator when the prefix has balanced parentheses
            if add_dot and encoded_fragment.count("(") == encoded_fragment.count(")"):
                encoded_fragment = encoded_fragment.rstrip(".") + "."

            sequences = self._generate(
                n_samples=n_samples_per_trial, safe_prefix=encoded_fragment, **kwargs
            )

            sequences = self._decode_safe(sequences, canonical=True, remove_invalid=sanitize)
            total_sequences.extend(sequences)

        return total_sequences

    def _generate(
        self,
        n_samples: int = 1,
        safe_prefix: Optional[str] = None,
        max_length: Optional[int] = 100,
        how: Optional[str] = "random",
        num_beams: Optional[int] = None,
        num_beam_groups: Optional[int] = None,
        do_sample: Optional[bool] = None,
        **kwargs,
    ):
        """Sample a new sequence using the underlying hugging face model.
        This emulates the izanagi sampling models, if you wish to retain the hugging face generation
        behaviour, either call the hugging face functions directly or overwrite this function

        ??? note "Generation Parameters"
            From the hugging face documentation:

            * `greedy decoding` if how="greedy" and num_beams=1 and do_sample=False.
            * `multinomial sampling` if num_beams=1 and do_sample=True.
            * `beam-search decoding` if how="beam" and num_beams>1 and do_sample=False.
            * `beam-search multinomial` sampling by calling if beam=True, num_beams>1 and do_sample=True or how="random" and num_beams>1
            * `diverse beam-search decoding` if num_beams>1 and num_beam_groups>1

            It's also possible to ignore the 'how' shortcut and directly call the underlying generation methods using the proper arguments.
            Learn more here: https://huggingface.co/docs/transformers/v4.32.0/en/main_classes/text_generation#transformers.GenerationConfig
            Under the hood, the following will be applied depending on the arguments:

            * greedy decoding by calling greedy_search() if num_beams=1 and do_sample=False
            * contrastive search by calling contrastive_search() if penalty_alpha>0. and top_k>1
            * multinomial sampling by calling sample() if num_beams=1 and do_sample=True
            * beam-search decoding by calling beam_search() if num_beams>1 and do_sample=False
            * beam-search multinomial sampling by calling beam_sample() if num_beams>1 and do_sample=True
            * diverse beam-search decoding by calling group_beam_search(), if num_beams>1 and num_beam_groups>1
            * constrained beam-search decoding by calling constrained_beam_search(), if constraints!=None or force_words_ids!=None
            * assisted decoding by calling assisted_decoding(), if assistant_model is passed to .generate()

        Args:
            n_samples: number of sequences to return
            safe_prefix: Prefix to use in sampling, should correspond to a safe fragment
            max_length : maximum length of sampled sequence
            how: which sampling method to use: "beam", "greedy" or "random". Can be used to control other parameters by setting defaults
            num_beams: number of beams for beam search. 1 means no beam search, unless beam is specified then max(n_samples, num_beams) is used
            num_beam_groups: number of beam groups for diverse beam search
            do_sample: whether to perform random sampling or not, equivalent to setting random to True
            kwargs: any additional keyword argument to pass to the underlying sampling `generate`  from hugging face transformer

        Returns:
            samples: list of sampled molecules, including failed validation

        Raises:
            ValueError: if the greedy path is selected while `num_beams` is greater than 1
        """
        pretrained_tk = self.tokenizer.get_pretrained()
        # make sure the tokenizer has a maximum length set before it is used
        if getattr(pretrained_tk, "model_max_length") is None:
            setattr(
                pretrained_tk,
                "model_max_length",
                self._DEFAULT_MAX_LENGTH,  # this was the default
            )

        input_ids = safe_prefix
        if isinstance(safe_prefix, str):
            # EN: should we address the special token issues
            input_ids = pretrained_tk(
                safe_prefix,
                return_tensors="pt",
            )

        # normalise falsy values: 0 beams becomes None, an unset do_sample becomes False
        num_beams = num_beams or None
        do_sample = do_sample or False

        if how == "random":
            do_sample = True

        elif how is not None and "beam" in how:
            # for beam search use at least as many beams as requested samples
            num_beams = max((num_beams or 0), n_samples)

        # NOTE: `and` binds tighter than `or`, so this reads as
        # how == "greedy" or (no beam search and no sampling)
        is_greedy = how == "greedy" or (num_beams in [0, 1, None]) and do_sample is False

        # the values computed here overwrite any same-named entries supplied in kwargs
        kwargs["do_sample"] = do_sample
        if num_beams is not None:
            kwargs["num_beams"] = num_beams
        if num_beam_groups is not None:
            kwargs["num_beam_groups"] = num_beam_groups
        kwargs["output_scores"] = True
        kwargs["return_dict_in_generate"] = True
        kwargs["num_return_sequences"] = n_samples
        kwargs["max_length"] = max_length
        kwargs.setdefault("early_stopping", True)
        # EN we don't do anything with the score that the model might return on generate ...
        if not isinstance(input_ids, Mapping):
            # no usable tokenized prefix (e.g. safe_prefix is None): generate without inputs.
            # NOTE(review): any non-str, non-Mapping safe_prefix is discarded here.
            input_ids = {"inputs": None}
        else:
            # EN: we remove the EOS token added before running the prediction
            # because the model output nonsense when we keep it.
            for k in input_ids:
                input_ids[k] = input_ids[k][:, :-1]

        # move tensor inputs to the device the model lives on
        for k, v in input_ids.items():
            if torch.is_tensor(v):
                input_ids[k] = v.to(self.model.device)

        # we remove the token_type_ids to support more model type than just GPT2
        input_ids.pop("token_type_ids", None)

        if is_greedy:
            kwargs["num_return_sequences"] = 1
            if num_beams is not None and num_beams > 1:
                raise ValueError("Cannot set num_beams|num_beam_groups > 1 for greedy")
            # under greedy decoding there can only be a single solution
            # we just duplicate the solution several time for efficiency
            outputs = self.model.generate(
                **input_ids,
                generation_config=self.generation_config,
                **kwargs,
            )
            sequences = [
                pretrained_tk.decode(outputs.sequences.squeeze(), skip_special_tokens=True)
            ] * n_samples

        else:
            outputs = self.model.generate(
                **input_ids,
                generation_config=self.generation_config,
                **kwargs,
            )
            sequences = pretrained_tk.batch_decode(outputs.sequences, skip_special_tokens=True)
        return sequences

__init__(model, tokenizer, generation_config=None, safe_encoder=None, verbose=True)

SAFEDesign constructor

Info

Design methods in SAFE are not deterministic when it comes to the token sampling step. If a method accepts a random_seed, it's for the SAFE-related algorithms and not the sampling from the autoregressive model. To ensure you get a deterministic sampling, please set the seed at the transformers package level.

import safe as sf
import transformers
my_seed = 100
designer = sf.SAFEDesign(...)

transformers.set_seed(my_seed) # use this before calling a design function
designer.linker_generation(...)

Parameters:

Name Type Description Default
model Union[SAFEDoubleHeadsModel, str]

input SAFEDoubleHeadsModel to use for generation

required
tokenizer Union[str, SAFETokenizer]

input SAFETokenizer to use for generation

required
generation_config Optional[Union[str, GenerationConfig]]

input GenerationConfig to use for generation

None
safe_encoder Optional[SAFEConverter]

custom safe encoder to use

None
verbose bool

whether to print out logging information during generation

True
Source code in safe/sample.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def __init__(
    self,
    model: Union[SAFEDoubleHeadsModel, str],
    tokenizer: Union[str, SAFETokenizer],
    generation_config: Optional[Union[str, GenerationConfig]] = None,
    safe_encoder: Optional[sf.SAFEConverter] = None,
    verbose: bool = True,
):
    """SAFEDesign constructor

    !!! info
        Design methods in SAFE are not deterministic when it comes to the token sampling step.
        If a method accepts a `random_seed`, it's for the SAFE-related algorithms and not the
        sampling from the autoregressive model. To ensure you get a deterministic sampling,
        please set the seed at the `transformers` package level.

        ```python
        import safe as sf
        import transformers
        my_seed = 100
        designer = sf.SAFEDesign(...)

        transformers.set_seed(my_seed) # use this before calling a design function
        designer.linker_generation(...)
        ```


    Args:
        model: input SAFEDoubleHeadsModel to use for generation, or a path/name to load it from
        tokenizer: input SAFETokenizer to use for generation, or a path to load it from
        generation_config: input GenerationConfig to use for generation, or a path/name to
            load it from. When None, it is derived from the model config.
        safe_encoder: custom safe encoder to use
        verbose: whether to print out logging information during generation
    """

    if isinstance(model, (str, os.PathLike)):
        model = SAFEDoubleHeadsModel.from_pretrained(model)

    if isinstance(tokenizer, (str, os.PathLike)):
        tokenizer = SAFETokenizer.load(tokenizer)

    model.eval()
    self.model = model
    self.tokenizer = tokenizer
    # Accept plain strings too, as the annotation promises and as is done for
    # `model` and `tokenizer` above; a str used to be stored unloaded and then
    # failed on the special-token lookup below.
    if isinstance(generation_config, (str, os.PathLike)):
        generation_config = GenerationConfig.from_pretrained(generation_config)
    if generation_config is None:
        generation_config = GenerationConfig.from_model_config(model.config)
    self.generation_config = generation_config
    # fill in special token ids missing from the generation config using the tokenizer's
    for special_token_id in ["bos_token_id", "eos_token_id", "pad_token_id"]:
        if getattr(self.generation_config, special_token_id) is None:
            setattr(
                self.generation_config, special_token_id, getattr(tokenizer, special_token_id)
            )

    self.verbose = verbose
    self.safe_encoder = safe_encoder or sf.SAFEConverter()

__mix_sequences(prefix_sequences, suffix_sequences, prefix, suffix, n_samples, mol_linker_slicer)

Use generated prefix and suffix sequences to form new molecules that merge both. This is the two-step scaffold morphing and linker generation scheme. Args: prefix_sequences: list of prefix sequences; suffix_sequences: list of suffix sequences; prefix: decoded smiles of the prefix; suffix: decoded smiles of the suffix; n_samples: number of samples to generate; mol_linker_slicer: slicer used to extract the linkers and link the fragments.

Source code in safe/sample.py
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
def __mix_sequences(
    self,
    prefix_sequences: List[str],
    suffix_sequences: List[str],
    prefix: str,
    suffix: str,
    n_samples: int,
    mol_linker_slicer,
):
    """Use generated prefix and suffix sequences to form new molecules
    that will be the merging of both. This is the two step scaffold morphing and linker generation scheme

    Args:
        prefix_sequences: list of prefix sequences
        suffix_sequences: list of suffix sequences
        prefix: decoded smiles of the prefix
        suffix: decoded smiles of the suffix
        n_samples: number of samples to generate
        mol_linker_slicer: callable invoked as `mol_linker_slicer(mol, query)`, whose
            result's second item is used as the linker; it must also expose
            `link_fragments(linker, prefix, suffix)`

    Returns:
        At most `n_samples` non-empty linked results.
    """
    prefix_query = dm.from_smarts(prefix)
    suffix_query = dm.from_smarts(suffix)

    # Extract one linker per generated sequence, prefix-derived ones first.
    # Sequences that fail to parse or slice are skipped (best effort).
    linkers = []
    for sequences, query in (
        (prefix_sequences, prefix_query),
        (suffix_sequences, suffix_query),
    ):
        for seq in sequences:
            with suppress(Exception):
                sliced = mol_linker_slicer(dm.to_mol(seq), query)
                if sliced[1] is not None:
                    linkers.append(sliced[1])

    linked = []
    for linker in linkers:
        # Stop on the number of molecules actually produced, not on the number of
        # linkers visited.
        if len(linked) >= n_samples:
            break
        candidates = mol_linker_slicer.link_fragments(linker, prefix, suffix)
        # Filter before accumulating so that no falsy entry can reach the result
        # (the filter used to be skipped on the iteration that hit `break`).
        linked.extend(x for x in candidates if x)
    return linked[:n_samples]

de_novo_generation(n_samples_per_trial=10, sanitize=False, n_trials=None, **kwargs)

Perform de novo generation using the pretrained SAFE model.

De novo generation is equivalent to not having any prefix.

Parameters:

Name Type Description Default
n_samples_per_trial int

number of new molecules to generate

10
sanitize bool

whether to perform sanitization, aka, perform control to ensure what is asked is what is returned

False
n_trials Optional[int]

number of randomization to perform

None
kwargs Optional[Dict[Any, Any]]

any argument to provide to the underlying generation function

{}
Source code in safe/sample.py
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
def de_novo_generation(
    self,
    n_samples_per_trial: int = 10,
    sanitize: bool = False,
    n_trials: Optional[int] = None,
    **kwargs: Optional[Dict[Any, Any]],
):
    """Generate molecules from scratch with the pretrained SAFE model.

    No prefix is given to the model: every trial asks `self._generate` for
    `n_samples_per_trial` sequences, and everything collected is decoded at the end.

    Args:
        n_samples_per_trial: number of sequences requested from the model at each trial
        sanitize: forwarded as `remove_invalid` when decoding; when True (and `self.verbose` is set)
            a summary of how many molecules remain is logged
        n_trials: number of generation rounds; a falsy value (None or 0) means a single round
        kwargs: extra arguments forwarded to the underlying generation function.
            `how` defaults to "random" when the caller does not provide it.

    Returns:
        The result of `self._decode_safe` applied to all the generated sequences.
    """
    # Unless the caller decided otherwise, sample randomly.
    if "how" not in kwargs:
        kwargs["how"] = "random"

    sampling_disabled = kwargs["how"] != "random" and not kwargs.get("do_sample")
    if sampling_disabled:
        logger.warning(
            "I don't think you know what you are doing ... for de novo generation `do_sample=True` or `how='random'` is expected !"
        )

    n_trials = n_trials or 1
    raw_sequences = []
    trials = tqdm(range(n_trials), disable=(not self.verbose), leave=False)
    for _trial in trials:
        raw_sequences.extend(self._generate(n_samples=n_samples_per_trial, **kwargs))

    decoded = self._decode_safe(raw_sequences, canonical=True, remove_invalid=sanitize)

    if sanitize and self.verbose:
        n_valid = len(decoded)
        n_requested = n_samples_per_trial * n_trials
        logger.info(
            f"After sanitization, {n_valid} / {n_requested} ({n_valid*100/n_requested:.2f} %) generated molecules are valid !"
        )
    return decoded

linker_generation(*groups, n_samples_per_trial=10, n_trials=1, sanitize=False, do_not_fragment_further=True, random_seed=None, model_only=False, **kwargs)

Perform linker generation using the pretrained SAFE model. Linker generation is really just scaffold morphing under the hood.

Parameters:

Name Type Description Default
groups Union[str, Mol]

list of fragments to link together, they are joined in the order provided

()
n_samples_per_trial int

number of new molecules to generate for each randomization

10
n_trials Optional[int]

number of randomization to perform

1
do_not_fragment_further Optional[bool]

whether to fragment the scaffold further or not

True
sanitize bool

whether to sanitize the generated molecules

False
random_seed Optional[int]

random seed to use

None
model_only Optional[bool]

whether to use the model only ability and nothing more.

False
kwargs Optional[Dict[Any, Any]]

any argument to provide to the underlying generation function

{}
Source code in safe/sample.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def linker_generation(
    self,
    *groups: Union[str, dm.Mol],
    n_samples_per_trial: int = 10,
    n_trials: Optional[int] = 1,
    sanitize: bool = False,
    do_not_fragment_further: Optional[bool] = True,
    random_seed: Optional[int] = None,
    model_only: Optional[bool] = False,
    **kwargs: Optional[Dict[Any, Any]],
):
    """Generate linkers between two fragments with the pretrained SAFE model.

    This is a thin front-end over `self._fragment_linking`, called with `is_linking=True`
    and the two groups as side chains.

    Args:
        groups: the fragments to link together, joined in the order provided. Exactly two are required.
        n_samples_per_trial: number of new molecules to generate for each randomization
        n_trials: number of randomization to perform
        sanitize: whether to sanitize the generated molecules
        do_not_fragment_further: whether to fragment the scaffold further or not
        random_seed: random seed to use
        model_only: whether to use the model only ability and nothing more.
        kwargs: any argument to provide to the underlying generation function

    Raises:
        ValueError: if the number of groups provided is not exactly two.
    """
    if len(groups) != 2:
        raise ValueError(
            "Linker generation only works when providing two groups as side chains"
        )

    # Everything below is forwarded untouched to the shared fragment-linking routine.
    linking_options = dict(
        side_chains=list(groups),
        n_samples_per_trial=n_samples_per_trial,
        n_trials=n_trials,
        sanitize=sanitize,
        do_not_fragment_further=do_not_fragment_further,
        random_seed=random_seed,
        is_linking=True,
        model_only=model_only,
    )
    return self._fragment_linking(**linking_options, **kwargs)

load_default(model_dir=None, device=None, verbose=False, **kwargs) classmethod

Load default SAFEGenerator model

Parameters:

Name Type Description Default
verbose bool

whether to print out logging information during generation

False
model_dir Optional[str]

Optional path to model folder to use instead of the default one. If provided the tokenizer should be in the model_dir named as tokenizer.json

None
device str

optional device where to move the model

None
kwargs Any

any additional argument to pass to the init function

{}
Source code in safe/sample.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
@classmethod
def load_default(
    cls,
    model_dir: Optional[str] = None,
    device: str = None,
    verbose: bool = False,
    **kwargs: Any,
) -> "SAFEDesign":
    """Build an instance from a pretrained SAFE model folder.

    The model, the tokenizer and the generation config are all loaded from the same folder.

    Args:
        model_dir: Optional path to model folder to use instead of the default one.
            Any falsy value (None, empty string) selects `cls._DEFAULT_MODEL_PATH`.
            If provided the tokenizer should be in the model_dir named as `tokenizer.json`
        device: optional device where to move the model
        verbose: whether to print out logging information during generation
        kwargs: any additional argument to pass to the init function

    Returns:
        A new instance of `cls` wrapping the loaded model, tokenizer and generation config.
    """
    checkpoint_dir = model_dir or cls._DEFAULT_MODEL_PATH

    pretrained_model = SAFEDoubleHeadsModel.from_pretrained(checkpoint_dir)
    pretrained_tokenizer = SAFETokenizer.from_pretrained(checkpoint_dir)
    generation_config = GenerationConfig.from_pretrained(checkpoint_dir)

    # Only relocate the model when a device was explicitly requested.
    if device is not None:
        pretrained_model = pretrained_model.to(device)

    init_options = dict(
        model=pretrained_model,
        tokenizer=pretrained_tokenizer,
        generation_config=generation_config,
        verbose=verbose,
    )
    return cls(**init_options, **kwargs)

load_from_wandb(artifact_path, device=None, verbose=True, **kwargs) classmethod

Load SAFE model and tokenizer from a Weights and Biases (wandb) artifact. By default, the model will be downloaded into SAFE_MODEL_ROOT.

Parameters:

Name Type Description Default
artifact_path str

The path to the wandb artifact in the format entity/project/artifact:version.

required
device Optional[str]

The device where the model should be loaded ('cpu' or 'cuda'). If None, it defaults to the available device.

None
verbose bool

Whether to print out logging information during generation.

True

Returns:

Name Type Description
SAFEDesign SAFEDesign

An instance of SAFEDesign class with the model, tokenizer, and generation config loaded from wandb.

Source code in safe/sample.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
@classmethod
def load_from_wandb(
    cls, artifact_path: str, device: Optional[str] = None, verbose: bool = True, **kwargs: Any
) -> "SAFEDesign":
    """
    Download a Weights and Biases (wandb) model artifact and build a SAFEDesign instance from it.

    The artifact is downloaded under the folder given by the `SAFE_MODEL_ROOT` environment variable
    when it is defined; otherwise no download root is passed to wandb.

    Args:
        artifact_path: The path to the wandb artifact in the format `entity/project/artifact:version`.
            A leading `wandb://` scheme is stripped. When the path contains no `/`, the project is read from
            the `SAFE_WANDB_PROJECT` environment variable (default `safe-models`). When the artifact
            name has no `:version` part, `:latest` is appended.
        device: The device where the model should be moved. If None, the model is left where it was loaded.
        verbose: Whether to print out logging information during generation.
        kwargs: Any additional argument to pass to the init function.

    Returns:
        SAFEDesign: An instance of SAFEDesign class with the model, tokenizer, and generation config loaded from wandb.
    """
    # EN: potentially remove wandb scheme
    import wandb

    stripped_path = artifact_path.replace("wandb://", "")

    # Everything before the first "/" is the project, the remainder is the artifact name.
    project_name, separator, artifact_name = stripped_path.partition("/")
    if not separator:
        project_name = os.getenv("SAFE_WANDB_PROJECT", "safe-models")
        artifact_name = stripped_path

    if ":" not in artifact_name:
        artifact_name = artifact_name + ":latest"

    qualified_path = f"{project_name}/{artifact_name}"

    # When SAFE_MODEL_ROOT is set, download into a per-artifact subfolder of it.
    download_root = None
    model_root = os.getenv("SAFE_MODEL_ROOT", None)
    if model_root is not None:
        root_dir = Path(model_root)
        root_dir.mkdir(parents=True, exist_ok=True)
        subfolder_name = qualified_path.replace("/", "_")
        subfolder_name = subfolder_name.replace(":", "_")
        download_root = (root_dir / subfolder_name).as_posix()

    wandb_api = wandb.Api()
    model_artifact = wandb_api.artifact(qualified_path, type="model")
    local_dir = model_artifact.download(root=download_root)

    # Model, tokenizer and generation config all live in the downloaded folder.
    loaded_model = SAFEDoubleHeadsModel.from_pretrained(local_dir)
    loaded_tokenizer = SAFETokenizer.from_pretrained(local_dir)
    loaded_gen_config = GenerationConfig.from_pretrained(local_dir)

    if device is not None:
        loaded_model = loaded_model.to(device)

    init_options = dict(
        model=loaded_model,
        tokenizer=loaded_tokenizer,
        generation_config=loaded_gen_config,
        verbose=verbose,
    )
    return cls(**init_options, **kwargs)

motif_extension(motif, n_samples_per_trial=10, n_trials=1, sanitize=False, do_not_fragment_further=True, random_seed=None, **kwargs)

Perform motif extension using the pretrained SAFE model. Motif extension is really just scaffold decoration under the hood.

Parameters:

Name Type Description Default
motif Union[str, Mol]

scaffold (with attachment points) to decorate

required
n_samples_per_trial int

number of new molecules to generate for each randomization

10
n_trials Optional[int]

number of randomization to perform

1
do_not_fragment_further Optional[bool]

whether to fragment the scaffold further or not

True
sanitize bool

whether to sanitize the generated molecules and check

False
random_seed Optional[int]

random seed to use

None
kwargs Optional[Dict[Any, Any]]

any argument to provide to the underlying generation function

{}
Source code in safe/sample.py
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
def motif_extension(
    self,
    motif: Union[str, dm.Mol],
    n_samples_per_trial: int = 10,
    n_trials: Optional[int] = 1,
    sanitize: bool = False,
    do_not_fragment_further: Optional[bool] = True,
    random_seed: Optional[int] = None,
    **kwargs: Optional[Dict[Any, Any]],
):
    """Extend a motif using the pretrained SAFE model.

    This simply delegates to `self.scaffold_decoration` with `add_dot=True`.

    Args:
        motif: scaffold (with attachment points) to decorate
        n_samples_per_trial: number of new molecules to generate for each randomization
        n_trials: number of randomization to perform
        sanitize: whether to sanitize the generated molecules and check
        do_not_fragment_further: whether to fragment the scaffold further or not
        random_seed: random seed to use
        kwargs: any argument to provide to the underlying generation function

    Returns:
        Whatever `self.scaffold_decoration` returns for the motif.
    """
    decoration_options = dict(
        n_samples_per_trial=n_samples_per_trial,
        n_trials=n_trials,
        sanitize=sanitize,
        do_not_fragment_further=do_not_fragment_further,
        random_seed=random_seed,
        add_dot=True,
    )
    return self.scaffold_decoration(motif, **decoration_options, **kwargs)

pattern_decoration(scaffold, n_samples_per_trial=10, n_trials=1, do_not_fragment_further=True, sanitize=False, random_seed=None, add_dot=True, n_scaff_random=3, n_scaff_samples=10, scaff_temperature=1.0, **kwargs)

Perform pattern decoration using the pretrained SAFE model. The pattern decoration algorithm works by first exemplifying the pattern as a set of scaffolds, then performing scaffold decoration on each scaffold.

Warning

Designing molecules from a given molecule pattern is more challenging than fragment-constrained design. SAFE does not currently support complex SMARTS pattern schemes (e.g., valence or connectivity constraints, some ring constraints). This function works best when sampling given a list of atoms. However, sampling depends on the model's conditional probabilities, meaning that if the model assigns zero probability to a token, you are unlikely to see it.

Parameters:

Name Type Description Default
scaffold Union[str, Mol]

Scaffold (with attachment points) to decorate.

required
n_samples_per_trial int

Number of new molecules to generate for each randomization.

10
n_trials int

Number of randomizations to perform.

1
do_not_fragment_further bool

Whether to prevent further fragmentation of the scaffold.

True
sanitize bool

Whether to sanitize the generated molecules and ensure the scaffold is present.

False
random_seed Optional[int]

Seed for randomization.

None
n_scaff_random Optional[int]

Number of scaffold randomizations to try (to reposition constraints in the string and increase rollout likelihood). Increasing this will improve sampling, but will require more time.

3
n_scaff_samples Optional[int]

Maximum number of samples to sample for a given scaffold from the pattern. Increasing this will make sure you have more diversity in the scaffold coming from the pattern

10
scaff_temperature float

Temperature to use when sampling valid scaffolds from the pattern. Higher temperature means more diverse scaffold

1.0
kwargs Optional[Dict[Any, Any]]

Additional arguments for the underlying generation function.

{}

Returns:

Type Description
List[str]

List of decorated molecule sequences.

Source code in safe/sample.py
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
def pattern_decoration(
    self,
    scaffold: Union[str, dm.Mol],
    n_samples_per_trial: int = 10,
    n_trials: int = 1,
    do_not_fragment_further: bool = True,
    sanitize: bool = False,
    random_seed: Optional[int] = None,
    add_dot: bool = True,
    n_scaff_random: Optional[int] = 3,
    n_scaff_samples: Optional[int] = 10,
    scaff_temperature: float = 1.0,
    **kwargs: Optional[Dict[Any, Any]],
) -> List[str]:
    """
    Perform pattern decoration using the pretrained SAFE model. The pattern decoration algorithm works by first exemplifying the pattern
    as a set of scaffolds, then performing scaffold decoration on each scaffold.

    !!! warning
        Designing molecules from a given molecule pattern is more challenging than fragment-constrained design.
        SAFE does not currently support complex SMARTS pattern schemes (e.g., valence or connectivity constraints, some ring constraints).
        This function works best when sampling given a list of atoms. However, sampling depends on the model's conditional probabilities,
        meaning that if the model assigns zero probability to a token, you are unlikely to see it.

    Args:
        scaffold: Scaffold (with attachment points) to decorate.
        n_samples_per_trial: Number of new molecules to generate for each randomization.
            The returned list is truncated to this many sequences.
        n_trials: Number of randomizations to perform.
        do_not_fragment_further: Whether to prevent further fragmentation of the scaffold.
        sanitize: Whether to sanitize the generated molecules and ensure the scaffold is present.
        random_seed: Seed for randomization.
        add_dot: Forwarded as-is to the underlying completion routine.
        n_scaff_random: Number of scaffold randomizations to try (to reposition constraints in the string and increase rollout likelihood).
            Increasing this will improve sampling, but will require more time. A falsy value disables the randomization.
        n_scaff_samples: Maximum number of samples to sample for a given scaffold from the pattern.
            Increasing this will make sure you have more diversity in the scaffold coming from the pattern.
            `None` means no extra cap: `n_samples_per_trial` scaffolds are sampled.
        scaff_temperature: Temperature to use when sampling valid scaffolds from the pattern. Higher temperature means more diverse scaffold
        kwargs: Additional arguments for the underlying generation function.

    Returns:
        List of decorated molecule sequences.
    """

    smarts_scaffolds = [scaffold]
    if n_scaff_random and n_scaff_random > 0:
        smarts_scaffolds = PatternConstraint.randomize(scaffold, n_scaff_random)

    # `n_scaff_samples` is Optional: `min(int, None)` would raise a TypeError,
    # so None is interpreted as "do not cap below n_samples_per_trial".
    if n_scaff_samples is None:
        n_scaffolds_to_sample = n_samples_per_trial
    else:
        n_scaffolds_to_sample = min(n_samples_per_trial, n_scaff_samples)

    all_scaffolds = set()
    for sm in smarts_scaffolds:
        cur_dec_pattern = PatternConstraint(sm, self.tokenizer, temperature=scaff_temperature)
        decorator = PatternSampler(self.model, cur_dec_pattern)
        cur_scaffolds = decorator.sample_scaffolds(
            n_samples=n_scaffolds_to_sample,
            n_trials=1,
            random_seed=random_seed,
        )
        all_scaffolds.update(cur_scaffolds)

    if sanitize:
        # keep only the scaffolds that can be parsed back as SMARTS
        all_scaffolds = [x for x in all_scaffolds if dm.from_smarts(x) is not None]
    total_sequences = []
    for scaff in all_scaffolds:
        # best effort: a scaffold that fails to be completed is silently skipped
        with suppress(Exception), dm.without_rdkit_log():
            cur_sequences = self._completion(
                fragment=dm.from_smarts(scaff),
                # split the sampling budget between the scaffolds (at least one sample each)
                n_samples_per_trial=int(n_samples_per_trial / max(len(all_scaffolds), 1)) + 1,
                n_trials=n_trials,
                do_not_fragment_further=do_not_fragment_further,
                sanitize=sanitize,
                random_seed=random_seed,
                add_dot=add_dot,
                **kwargs,
            )
            total_sequences.extend(cur_sequences)

    random.shuffle(total_sequences)
    if sanitize:
        total_sequences = sf.utils.filter_by_substructure_constraints(total_sequences, scaffold)
        total_sequences = total_sequences[:n_samples_per_trial]
        if self.verbose:
            n_requested = n_samples_per_trial * n_trials
            # guard the percentage against a zero-sized request
            valid_percent = len(total_sequences) * 100 / n_requested if n_requested else 0.0
            logger.info(
                f"After sanitization, {len(total_sequences)} / {n_requested} "
                f"({valid_percent:.2f}%) generated molecules are valid!"
            )

    return total_sequences[:n_samples_per_trial]

scaffold_decoration(scaffold, n_samples_per_trial=10, n_trials=1, do_not_fragment_further=True, sanitize=False, random_seed=None, add_dot=True, **kwargs)

Perform scaffold decoration using the pretrained SAFE model

For scaffold decoration, we basically start with a prefix that carries the attachment points. We first convert the prefix into a valid SAFE string.

Parameters:

Name Type Description Default
scaffold Union[str, Mol]

scaffold (with attachment points) to decorate

required
n_samples_per_trial int

number of new molecules to generate for each randomization

10
n_trials Optional[int]

number of randomization to perform

1
do_not_fragment_further Optional[bool]

whether to fragment the scaffold further or not

True
sanitize bool

whether to sanitize the generated molecules and check if the scaffold is still present

False
random_seed Optional[int]

random seed to use

None
kwargs Optional[Dict[Any, Any]]

any argument to provide to the underlying generation function

{}
Source code in safe/sample.py
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
def scaffold_decoration(
    self,
    scaffold: Union[str, dm.Mol],
    n_samples_per_trial: int = 10,
    n_trials: Optional[int] = 1,
    do_not_fragment_further: Optional[bool] = True,
    sanitize: bool = False,
    random_seed: Optional[int] = None,
    add_dot: Optional[bool] = True,
    **kwargs: Optional[Dict[Any, Any]],
):
    """Perform scaffold decoration using the pretrained SAFE model

    For scaffold decoration, we basically start with a prefix with the attachment point.
    We first convert the prefix into a valid safe string.

    Args:
        scaffold: scaffold (with attachment points) to decorate
        n_samples_per_trial: number of new molecules to generate for each randomization
        n_trials: number of randomization to perform
        do_not_fragment_further: whether to fragment the scaffold further or not
        sanitize: whether to sanitize the generated molecules and check if the scaffold is still present
        random_seed: random seed to use
        add_dot: forwarded as-is to the underlying completion routine
        kwargs: any argument to provide to the underlying generation function

    Returns:
        The generated sequences; when `sanitize` is set, only those passing
        `sf.utils.filter_by_substructure_constraints` against the scaffold.
    """

    total_sequences = self._completion(
        fragment=scaffold,
        n_samples_per_trial=n_samples_per_trial,
        n_trials=n_trials,
        do_not_fragment_further=do_not_fragment_further,
        sanitize=sanitize,
        random_seed=random_seed,
        add_dot=add_dot,
        **kwargs,
    )
    # if we require sanitization
    # then we should filter out molecules that do not match the requested
    if sanitize:
        total_sequences = sf.utils.filter_by_substructure_constraints(total_sequences, scaffold)
        if self.verbose:
            # `n_trials` is Optional: use the same `n_trials or 1` fallback as
            # `de_novo_generation` so the statistics never fail on None or 0.
            n_requested = n_samples_per_trial * (n_trials or 1)
            n_valid = len(total_sequences)
            valid_percent = n_valid * 100 / n_requested if n_requested else 0.0
            logger.info(
                f"After sanitization, {n_valid} / {n_requested} ({valid_percent:.2f} %)  generated molecules are valid !"
            )
    return total_sequences

scaffold_morphing(side_chains=None, mol=None, core=None, n_samples_per_trial=10, n_trials=1, sanitize=False, do_not_fragment_further=True, random_seed=None, **kwargs)

Perform scaffold morphing decoration using the pretrained SAFE model

For scaffold morphing, we try to replace the core by a new one. If the side_chains are provided, we use them. If a combination of molecule and core is provided, then we use them to extract the side chains before performing the scaffold morphing.

Finding the side chains

The algorithm to find the side chains from core assumes that the core we get as input has attachment points. Those attachment points are never considered as part of the query, rather they are used to define the attachment points. See ~sf.utils.compute_side_chains for more information.

Parameters:

Name Type Description Default
side_chains Optional[Union[Mol, str, List[Union[str, Mol]]]]

side chains to use to perform scaffold morphing (joining as best as possible the set of fragments)

None
mol Optional[Union[Mol, str]]

input molecules when side_chains are not provided

None
core Optional[Union[Mol, str]]

core to morph into another scaffold

None
n_samples_per_trial int

number of new molecules to generate for each randomization

10
n_trials Optional[int]

number of randomization to perform

1
do_not_fragment_further Optional[bool]

whether to fragment the scaffold further or not

True
sanitize bool

whether to sanitize the generated molecules

False
random_seed Optional[int]

random seed to use

None
kwargs Optional[Dict[Any, Any]]

any argument to provide to the underlying generation function

{}
Source code in safe/sample.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
def scaffold_morphing(
    self,
    side_chains: Optional[Union[dm.Mol, str, List[Union[str, dm.Mol]]]] = None,
    mol: Optional[Union[dm.Mol, str]] = None,
    core: Optional[Union[dm.Mol, str]] = None,
    n_samples_per_trial: int = 10,
    n_trials: Optional[int] = 1,
    sanitize: bool = False,
    do_not_fragment_further: Optional[bool] = True,
    random_seed: Optional[int] = None,
    **kwargs: Optional[Dict[Any, Any]],
):
    """Morph the scaffold of a molecule using the pretrained SAFE model

    The goal is to replace the core by a new one. Either the side chains are given directly,
    or a molecule together with its core is given. Everything is forwarded to
    `self._fragment_linking` with `is_linking=False`.

    !!! note "Finding the side chains"
        The algorithm to find the side chains from core assumes that the core we get as input has attachment points.
        Those attachment points are never considered as part of the query, rather they are used to define the attachment points.
        See ~sf.utils.compute_side_chains for more information.

    Args:
        side_chains: side chains to use to perform scaffold morphing (joining as best as possible the set of fragments)
        mol: input molecules when side_chains are not provided
        core: core to morph into another scaffold
        n_samples_per_trial: number of new molecules to generate for each randomization
        n_trials: number of randomization to perform
        sanitize: whether to sanitize the generated molecules
        do_not_fragment_further: whether to fragment the scaffold further or not
        random_seed: random seed to use
        kwargs: any argument to provide to the underlying generation function

    Returns:
        Whatever `self._fragment_linking` returns for these inputs.
    """
    morphing_options = dict(
        side_chains=side_chains,
        mol=mol,
        core=core,
        n_samples_per_trial=n_samples_per_trial,
        n_trials=n_trials,
        sanitize=sanitize,
        do_not_fragment_further=do_not_fragment_further,
        random_seed=random_seed,
        is_linking=False,
    )
    return self._fragment_linking(**morphing_options, **kwargs)

super_structure(core, n_samples_per_trial=10, n_trials=1, sanitize=False, do_not_fragment_further=True, random_seed=None, attachment_point_depth=None, **kwargs)

Perform super structure generation using the pretrained SAFE model.

To generate super-structure, we basically just create various attachment points to the input core, then perform scaffold decoration.

Parameters:

Name Type Description Default
core Union[str, Mol]

input substructure to use. We aim to generate super structures of this molecule

required
n_samples_per_trial int

number of new molecules to generate for each randomization

10
n_trials Optional[int]

number of different attachment points to consider

1
do_not_fragment_further Optional[bool]

whether to fragment the scaffold further or not

True
sanitize bool

whether to sanitize the generated molecules

False
random_seed Optional[int]

random seed to use

None
attachment_point_depth Optional[int]

depth of opening the attachment points. Increasing this means you increase the number of substitution points to consider.

None
kwargs Optional[Dict[Any, Any]]

any argument to provide to the underlying generation function

{}
Source code in safe/sample.py
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
def super_structure(
    self,
    core: Union[str, dm.Mol],
    n_samples_per_trial: int = 10,
    n_trials: Optional[int] = 1,
    sanitize: bool = False,
    do_not_fragment_further: Optional[bool] = True,
    random_seed: Optional[int] = None,
    attachment_point_depth: Optional[int] = None,
    **kwargs: Optional[Dict[Any, Any]],
):
    """Generate super structures of a core using the pretrained SAFE model.

    Several variants of the core with opened attachment points are built first, then each
    trial runs a completion from one of those variants (cycling through them).

    Args:
        core: input substructure to use. We aim to generate super structures of this molecule
        n_samples_per_trial: number of new molecules to generate for each randomization
        n_trials: number of different attachment points to consider; a falsy value means 1
        sanitize: whether to sanitize the generated molecules
        do_not_fragment_further: whether to fragment the scaffold further or not
        random_seed: random seed to use
        attachment_point_depth: depth of opening the attachment points.
            Increasing this, means you increase the number of substitution point to consider.
        kwargs: any argument to provide to the underlying generation function

    Returns:
        The sequences accumulated over all trials. A trial that raises contributes nothing.
    """

    parent_mol = dm.to_mol(core)
    candidates = sf.utils.list_individual_attach_points(parent_mol, depth=attachment_point_depth)
    # the fully opened molecule is always part of the candidates
    fully_open = dm.to_smiles(dm.reactions.open_attach_points(parent_mol))
    candidates.append(fully_open)
    # remove duplicates, then shuffle with a dedicated generator seeded by `random_seed`
    candidates = list(set(candidates))
    random.Random(random_seed).shuffle(candidates)

    collected = []
    n_trials = n_trials or 1
    for trial_idx in tqdm(range(n_trials), disable=(not self.verbose), leave=False):
        # cycle through the candidates when there are more trials than candidates
        current_core = candidates[trial_idx % len(candidates)]
        previous_verbose = self.verbose
        try:
            # presumably silences the nested completion logging -- verify against sf.utils.attr_as
            with sf.utils.attr_as(self, "verbose", False):
                generated = self._completion(
                    fragment=current_core,
                    n_samples_per_trial=n_samples_per_trial,
                    n_trials=1,
                    do_not_fragment_further=do_not_fragment_further,
                    sanitize=sanitize,
                    random_seed=random_seed,
                    **kwargs,
                )
                collected.extend(generated)
        except Exception as e:
            # a failing trial is only reported, never propagated
            if previous_verbose:
                logger.error(e)

        finally:
            self.verbose = previous_verbose

    if sanitize and self.verbose:
        n_valid = len(collected)
        n_requested = n_samples_per_trial * n_trials
        logger.info(
            f"After sanitization, {n_valid} / {n_requested} ({n_valid*100/n_requested:.2f} %)  generated molecules are valid !"
        )
    return collected

SAFE Tokenizer

SAFESplitter

Standard Splitter for SAFE string

Source code in safe/tokenizer.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class SAFESplitter:
    """Regex-based splitter turning a SAFE string into a list of tokens."""

    # One capturing group listing every accepted token: bracket atoms, organic-subset
    # atoms, aromatic atoms, bonds, branches, ring closures (`%NN` or one digit), ...
    REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""

    name = "safe"

    def __init__(self, pattern: Optional[str] = None):
        """Compile the splitting regex.

        Args:
            pattern: regular expression to use; `REGEX_PATTERN` is used when None
        """
        # do not use this as raw strings (not r before)
        chosen_pattern = self.REGEX_PATTERN if pattern is None else pattern
        self.regex = re.compile(chosen_pattern)

    def tokenize(self, line):
        """Tokenize a safe string into characters.

        For a `str`, the tokens must rebuild the input exactly, otherwise the mismatch
        is logged and a `ValueError` carrying the input is raised. Any other object is
        matched through its `str()` form and sliced with the match offsets.
        """
        if not isinstance(line, str):
            matches = re.finditer(self.regex, str(line))
            return [line[m.start(0) : m.end(0)] for m in matches]

        tokens = list(self.regex.findall(line))
        rebuilt = "".join(tokens)
        if rebuilt != line:
            logger.error(
                f"Tokens different from sample:\ntokens {rebuilt}\nsample {line}."
            )
            raise ValueError(line)
        return tokens

    def detokenize(self, chars):
        """Detokenize SAFE notation: strip every token and glue them back together.

        A string input is first split on single spaces.
        """
        pieces = chars.split(" ") if isinstance(chars, str) else chars
        return "".join(piece.strip() for piece in pieces)

    def split(self, n, normalized):
        """Perform splitting for pretokenization (the first argument is not used)."""
        tokens = self.tokenize(normalized)
        return tokens

    def pre_tokenize(self, pretok):
        """Pretokenize using an input pretokenizer object from the tokenizer library"""
        split_callback = self.split
        pretok.split(split_callback)

detokenize(chars)

Detokenize SAFE notation

Source code in safe/tokenizer.py
75
76
77
78
79
def detokenize(self, chars):
    """Detokenize SAFE notation.

    A string is split on single spaces first; every piece is then stripped and
    all pieces are concatenated.
    """
    pieces = chars.split(" ") if isinstance(chars, str) else chars
    return "".join(piece.strip() for piece in pieces)

pre_tokenize(pretok)

Pretokenize using an input pretokenizer object from the tokenizer library

Source code in safe/tokenizer.py
85
86
87
def pre_tokenize(self, pretok):
    """Pretokenize using an input pretokenizer object from the tokenizer library

    Args:
        pretok: object exposing a ``split`` method; it receives ``self.split`` as callback
    """
    pretok.split(self.split)

split(n, normalized)

Perform splitting for pretokenization

Source code in safe/tokenizer.py
81
82
83
def split(self, n, normalized):
    """Perform splitting for pretokenization

    Args:
        n: not used by this implementation
        normalized: input forwarded unchanged to ``self.tokenize``

    Returns:
        whatever ``self.tokenize(normalized)`` returns
    """
    return self.tokenize(normalized)

tokenize(line)

Tokenize a safe string into characters.

Source code in safe/tokenizer.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def tokenize(self, line):
    """Split ``line`` into the list of tokens matched by ``self.regex``.

    For a plain string, the matched tokens must concatenate back to the input,
    otherwise the mismatch is logged and ``ValueError`` is raised. For any other
    input, the regex is run on ``str(line)`` and the original object is sliced at
    every match span.
    """
    if not isinstance(line, str):
        spans = re.finditer(self.regex, str(line))
        return [line[hit.start(0) : hit.end(0)] for hit in spans]

    pieces = list(self.regex.findall(line))
    rebuilt = "".join(pieces)
    if rebuilt != line:
        logger.error(
            f"Tokens different from sample:\ntokens {rebuilt}\nsample {line}."
        )
        raise ValueError(line)
    return pieces

SAFETokenizer

Bases: PushToHubMixin

Class to initialize and train a tokenizer for SAFE strings. Once trained, the tokenizer can be converted to a HuggingFace PreTrainedTokenizerFast.

Source code in safe/tokenizer.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
class SAFETokenizer(PushToHubMixin):
    """
    Class to initialize and train a tokenizer for SAFE string
    Once trained, you can use the converted version of the tokenizer to an HuggingFace PreTrainedTokenizerFast
    """

    vocab_files_names: str = "tokenizer.json"

    def __init__(
        self,
        tokenizer_type: str = "bpe",
        splitter: Optional[str] = "safe",
        trainer_args=None,
        decoder_args=None,
        token_model_args=None,
    ):
        """Build the underlying tokenizer, its trainer, pre-tokenizer, post-processor and decoder.

        Args:
            tokenizer_type: a value starting with "bpe" selects a BPE model and trainer,
                anything else (including None) selects a WordLevel model and trainer
            splitter: "safe" installs a `SAFESplitter` as custom pre-tokenizer, any other
                value leaves the pre-tokenizer untouched
            trainer_args: optional keyword arguments for the trainer
            decoder_args: optional keyword arguments for the BPE decoder
            token_model_args: optional keyword arguments for the token model
        """
        super().__init__()
        self.tokenizer_type = tokenizer_type
        self.trainer_args = trainer_args or {}
        self.decoder_args = decoder_args or {}
        self.token_model_args = token_model_args or {}

        use_bpe = tokenizer_type is not None and tokenizer_type.startswith("bpe")
        model_cls, trainer_cls = (BPE, BpeTrainer) if use_bpe else (WordLevel, WordLevelTrainer)
        self.model = model_cls(unk_token=UNK_TOKEN, **self.token_model_args)
        self.trainer = trainer_cls(special_tokens=SPECIAL_TOKENS, **self.trainer_args)

        self.tokenizer = Tokenizer(self.model)
        self.splitter = SAFESplitter() if splitter == "safe" else None
        if self.splitter is not None:
            self.tokenizer.pre_tokenizer = PreTokenizer.custom(self.splitter)
        self.tokenizer.post_processor = TemplateProcessing(
            single=TEMPLATE_SINGLE,
            pair=TEMPLATE_PAIR,
            special_tokens=TEMPLATE_SPECIAL_TOKENS,
        )
        self.tokenizer.decoder = decoders.BPEDecoder(**self.decoder_args)
        self.tokenizer = self.set_special_tokens(self.tokenizer)

    @property
    def bos_token_id(self):
        """Get the bos token id (looked up from the tokenizer's `bos_token` via `token_to_id`)"""
        return self.tokenizer.token_to_id(self.tokenizer.bos_token)

    @property
    def pad_token_id(self):
        """Get the pad token id (looked up from the tokenizer's `pad_token` via `token_to_id`)"""
        return self.tokenizer.token_to_id(self.tokenizer.pad_token)

    @property
    def eos_token_id(self):
        """Get the eos token id (looked up from the tokenizer's `eos_token` via `token_to_id`)"""
        return self.tokenizer.token_to_id(self.tokenizer.eos_token)

    @classmethod
    def set_special_tokens(
        cls,
        tokenizer: Tokenizer,
        bos_token: str = CLS_TOKEN,
        eos_token: str = SEP_TOKEN,
    ):
        """Attach the special tokens to a tokenizer and return it

        The pad, cls, sep, mask and unk tokens always come from the module constants;
        only the bos and eos tokens can be overridden.

        Args:
            tokenizer: tokenizer for which special tokens will be set
            bos_token: Optional bos token to use
            eos_token: Optional eos token to use

        Returns:
            the same tokenizer object, modified in place
        """
        token_attributes = {
            "pad_token": PADDING_TOKEN,
            "cls_token": CLS_TOKEN,
            "sep_token": SEP_TOKEN,
            "mask_token": MASK_TOKEN,
            "unk_token": UNK_TOKEN,
            "eos_token": eos_token,
            "bos_token": bos_token,
        }
        for attribute, token in token_attributes.items():
            setattr(tokenizer, attribute, token)

        # only `Tokenizer` instances get the tokens registered in their vocabulary
        if isinstance(tokenizer, Tokenizer):
            registered = [
                PADDING_TOKEN,
                CLS_TOKEN,
                SEP_TOKEN,
                MASK_TOKEN,
                UNK_TOKEN,
                eos_token,
                bos_token,
            ]
            tokenizer.add_special_tokens(registered)
        return tokenizer

    def train(self, files: Optional[List[str]], **kwargs):
        r"""
        Train the tokenizer from one or several text files, using `self.trainer`

        Args
            files: path, or list of paths, to files in which molecules are separated by new line
            kwargs (dict): accepted for signature compatibility; not forwarded to the
                underlying tokenizer `train` call
        """
        file_list = [files] if isinstance(files, str) else files
        self.tokenizer.train(files=file_list, trainer=self.trainer)

    def __getstate__(self):
        """Getting state to allow pickling

        Returns:
            a deep copy of the instance `__dict__` whose "tokenizer" entry carries a
            `Whitespace()` pre-tokenizer, plus a "tokenizer_attrs" entry holding a
            shallow copy of the tokenizer's own `__dict__`.
        """
        # the deep copy is taken with `Whitespace()` set as pre-tokenizer through `attr_as`
        # (presumably a temporary swap because a custom pre-tokenizer cannot be copied -- TODO confirm)
        with attr_as(self.tokenizer, "pre_tokenizer", Whitespace()):
            d = copy.deepcopy(self.__dict__)
        # copy back tokenizer level attribute
        d["tokenizer_attrs"] = self.tokenizer.__dict__.copy()
        d["tokenizer"].pre_tokenizer = Whitespace()
        # NOTE(review): no "custom_pre_tokenizer" key is stored in this state, while
        # `__setstate__` reads that key to decide whether to restore `SAFESplitter` -- confirm intent.
        return d

    def __setstate__(self, d):
        """Setting state during reloading pickling

        Args:
            d: state dictionary; its "tokenizer" entry is updated in place, then the
                whole dictionary is merged into the instance `__dict__`.
        """
        # NOTE(review): `__getstate__` of this class never writes "custom_pre_tokenizer",
        # so this lookup yields None for states it produced and the custom
        # pre-tokenizer branch below is skipped -- confirm intent.
        use_pretokenizer = d.get("custom_pre_tokenizer")
        if use_pretokenizer:
            d["tokenizer"].pre_tokenizer = PreTokenizer.custom(SAFESplitter())
        # restore the attributes saved under "tokenizer_attrs" onto the tokenizer object
        d["tokenizer"].__dict__.update(d.get("tokenizer_attrs", {}))
        self.__dict__.update(d)

    def train_from_iterator(self, data: Iterator, **kwargs: Any):
        """Train the Tokenizer using the provided iterator.

        The call is delegated to `self.tokenizer.train_from_iterator` with `self.trainer`
        as trainer.

        You can provide anything that is a Python Iterator
            * A list of sequences :obj:`List[str]`
            * A generator that yields :obj:`str` or :obj:`List[str]`
            * A Numpy array of strings

        Args:
            data: data iterator
            **kwargs: additional keyword argument for the tokenizer `train_from_iterator`
        """
        self.tokenizer.train_from_iterator(data, trainer=self.trainer, **kwargs)

    def __len__(self):
        r"""
        Gets the count of tokens in vocab along with special tokens.

        Computed as the number of keys returned by `self.tokenizer.get_vocab()`.
        """
        return len(self.tokenizer.get_vocab().keys())

    def encode(self, sample_str: str, ids_only: bool = True, **kwargs) -> list:
        r"""
        Encode one molecule string, or a batch of them, with the trained tokenizer

        Args:
            sample_str: a single string, or any non-string input which is then encoded as a batch
            ids_only: whether to return only the ids or the encoding object
            **kwargs: forwarded to the tokenizer `encode` / `encode_batch` call

        Returns:
            object: ids (or encoding) for a single string, list of ids (or encodings) for a batch
        """
        if not isinstance(sample_str, str):
            batch = self.tokenizer.encode_batch(sample_str, **kwargs)
            return [item.ids for item in batch] if ids_only else batch

        encoding = self.tokenizer.encode(sample_str, **kwargs)
        return encoding.ids if ids_only else encoding

    def to_dict(self, **kwargs):
        """Convert tokenizer to dict

        The dict is the parsed `to_str()` output of the tokenizer, extended with
        "tokenizer_type", "tokenizer_attrs" and, when a splitter is set,
        "custom_pre_tokenizer".
        """
        if self.splitter is not None:
            # HuggingFace tokenizers cannot serialize a custom pre-tokenizer, so the
            # dump is produced with a Whitespace pre-tokenizer set through `attr_as`
            with attr_as(self.tokenizer, "pre_tokenizer", Whitespace()):
                payload = json.loads(self.tokenizer.to_str())
                payload["custom_pre_tokenizer"] = True
        else:
            payload = json.loads(self.tokenizer.to_str())

        payload.update(
            tokenizer_type=self.tokenizer_type,
            tokenizer_attrs=self.tokenizer.__dict__,
        )
        return payload

    def save_pretrained(self, *args, **kwargs):
        """Save pretrained tokenizer

        Thin pass-through: all positional and keyword arguments are forwarded to
        `self.tokenizer.save_pretrained`.
        """
        self.tokenizer.save_pretrained(*args, **kwargs)

    def save(self, file_name=None):
        r"""
        Write the `to_dict()` representation of this tokenizer as JSON to the given path.

        Args:
            file_name (str, optional): File where to save tokenizer
        """
        # EN: whole logic here assumes noone is going to mess with the special token
        payload = self.to_dict()
        with fsspec.open(file_name, "w", encoding="utf-8") as handle:
            handle.write(json.dumps(payload, ensure_ascii=False))

    @classmethod
    def from_dict(cls, data: dict):
        """Load tokenizer from dict

        Args:
            data: dictionary containing the tokenizer info. It is left unmodified.

        Returns:
            a new instance of this class wrapping the deserialized tokenizer
        """
        # Work on a shallow copy: the SAFE-specific keys are popped before the remainder is
        # handed to `Tokenizer.from_str`, and popping them from the caller's dict would make
        # a second `from_dict` call on the same dict silently lose them.
        data = dict(data)
        tokenizer_type = data.pop("tokenizer_type", "safe")
        tokenizer_attrs = data.pop("tokenizer_attrs", None)
        custom_pre_tokenizer = data.pop("custom_pre_tokenizer", False)
        tokenizer = Tokenizer.from_str(json.dumps(data))
        if custom_pre_tokenizer:
            tokenizer.pre_tokenizer = PreTokenizer.custom(SAFESplitter())
        mol_tokenizer = cls(tokenizer_type)
        mol_tokenizer.tokenizer = mol_tokenizer.set_special_tokens(tokenizer)
        # restore the tokenizer-level attributes that were saved alongside the JSON dump
        if tokenizer_attrs and isinstance(tokenizer_attrs, dict):
            mol_tokenizer.tokenizer.__dict__.update(tokenizer_attrs)
        return mol_tokenizer

    @classmethod
    def load(cls, file_name):
        """Load a tokenizer from a JSON file

        Args:
            file_name: path (anything `fsspec.open` accepts) of the JSON file to read

        Returns:
            the tokenizer built by `from_dict` from the parsed file content
        """
        # `save` writes UTF-8 with `ensure_ascii=False`; read with the same encoding
        # instead of relying on the platform default.
        with fsspec.open(file_name, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        # EN: the rust json parser of tokenizers has a predefined structure;
        # `from_dict` strips the SAFE-specific keys before handing the data over
        return cls.from_dict(data)

    def decode(
        self,
        ids: list,
        skip_special_tokens: bool = True,
        ignore_stops: bool = False,
        stop_token_ids: Optional[List[int]] = None,
    ) -> str:
        r"""
        Decodes a list of ids to molecular representation in the format in which this tokenizer was created.

        Args:
            ids: list of IDs, or a list of such lists (lists, numpy arrays or tensors)
            skip_special_tokens: whether to skip all special tokens when encountering them
            ignore_stops: whether to ignore the stop tokens, thus decoding till the end
            stop_token_ids: optional list of stop token ids to use; defaults to the eos token id

        Returns:
            sequence: str representation of molecule when a single sequence is decoded,
                otherwise a list of str (one per input sequence)
        """
        # an empty input or a flat sequence of ids is handled as a single sample
        is_batch = len(ids) > 0 and (
            isinstance(ids[0], (list, np.ndarray)) or torch.is_tensor(ids[0])
        )
        batch = ids if is_batch else [ids]
        if not stop_token_ids:
            stop_token_ids = [self.tokenizer.token_to_id(self.tokenizer.eos_token)]

        kept_batch = []
        for seq in batch:
            kept = seq
            if not ignore_stops:
                kept = []
                # if first tokens are stop, we just remove it
                # this is because of bart essentially
                start = 0
                # we only ignore when there is a list of tokens
                if len(seq) > 1:
                    # bounded scan: a sequence made only of stop tokens must not overrun
                    while start < len(seq) and seq[start] in stop_token_ids:
                        start += 1
                for token_id in seq[start:]:
                    if int(token_id) in stop_token_ids:
                        break
                    kept.append(token_id)
            kept_batch.append(kept)
        if len(kept_batch) == 1:
            return self.tokenizer.decode(
                list(kept_batch[0]), skip_special_tokens=skip_special_tokens
            )
        return self.tokenizer.decode_batch(
            list(kept_batch), skip_special_tokens=skip_special_tokens
        )

    def get_pretrained(self, **kwargs) -> PreTrainedTokenizerFast:
        r"""
        Get a pretrained tokenizer from this tokenizer

        Args:
            **kwargs: accepted for signature compatibility; not used

        Returns:
            Returns pre-trained fast tokenizer for hugging face models.
        """
        # the fast tokenizer is built while a Whitespace pre-tokenizer is set through
        # `attr_as`; the current pre-tokenizer is then assigned to the wrapped tokenizer
        with attr_as(self.tokenizer, "pre_tokenizer", Whitespace()):
            tk = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
        tk._tokenizer.pre_tokenizer = self.tokenizer.pre_tokenizer
        # now we need to add special_tokens
        tk.add_special_tokens(
            {
                "cls_token": self.tokenizer.cls_token,
                "bos_token": self.tokenizer.bos_token,
                "eos_token": self.tokenizer.eos_token,
                "mask_token": self.tokenizer.mask_token,
                "pad_token": self.tokenizer.pad_token,
                "unk_token": self.tokenizer.unk_token,
                "sep_token": self.tokenizer.sep_token,
            }
        )
        # Explicit parentheses: without them `A or B and C` parses as `A or (B and C)`, so a
        # missing `model_max_length` on `tk` bypassed the `hasattr` guard and the
        # assignment below raised AttributeError.
        max_length_unset = tk.model_max_length is None or tk.model_max_length > 1e8
        if max_length_unset and hasattr(self.tokenizer, "model_max_length"):
            tk.model_max_length = self.tokenizer.model_max_length
        return tk

    def push_to_hub(
        self,
        repo_id: str,
        use_temp_dir: Optional[bool] = None,
        commit_message: Optional[str] = None,
        private: Optional[bool] = None,
        token: Optional[Union[bool, str]] = None,
        max_shard_size: Optional[Union[int, str]] = "10GB",
        create_pr: bool = False,
        safe_serialization: bool = False,
        **deprecated_kwargs,
    ) -> str:
        """
        Upload the tokenizer to the 🤗 Model Hub.

        Args:
            repo_id: The name of the repository you want to push your {object} to. It should contain your organization name
                when pushing to a given organization.
            use_temp_dir: Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
                Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
            commit_message: Message to commit while pushing. Will default to `"Upload {object}"`.
            private: Whether or not the repository created should be private.
            token: The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
                is not specified.
            max_shard_size: Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
                will then be each of size lower than this size. If expressed as a string, needs to be digits followed
                by a unit (like `"5MB"`).
            create_pr: Whether or not to create a PR with the uploaded files or directly commit.
            safe_serialization: Whether or not to convert the model weights in safetensors format for safer serialization.
            **deprecated_kwargs: deprecated `use_auth_token`, `repo_path_or_name`, `repo_url` and `organization` arguments.

        Returns:
            the value returned by `self._upload_modified_files`

        Raises:
            ValueError: if `token` and `use_auth_token`, or `repo_id` and `repo_path_or_name`, are both given.
        """
        use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
        if repo_path_or_name is not None:
            # Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
            # repo_id from the folder path, if it exists.
            warnings.warn(
                "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
                "`repo_id` instead.",
                FutureWarning,
            )
            if repo_id is not None:
                raise ValueError(
                    "`repo_id` and `repo_path_or_name` are both specified. Please set only the argument `repo_id`."
                )
            if os.path.isdir(repo_path_or_name):
                # repo_path: infer repo_id from the path. `repo_id` is necessarily None in this
                # branch, so the name must be derived from `repo_path_or_name`; `normpath`
                # drops a trailing separator that would otherwise give an empty name.
                repo_id = os.path.basename(os.path.normpath(repo_path_or_name))
                # the existing local folder is the directory to work in
                working_dir = repo_path_or_name
            else:
                # repo_name: use it as repo_id
                repo_id = repo_path_or_name
                working_dir = repo_id.split("/")[-1]
        else:
            # Repo_id is passed correctly: infer working_dir from it
            working_dir = repo_id.split("/")[-1]

        # Deprecation warning will be sent after for repo_url and organization
        repo_url = deprecated_kwargs.pop("repo_url", None)
        organization = deprecated_kwargs.pop("organization", None)

        repo_id = self._create_repo(
            repo_id, private, token, repo_url=repo_url, organization=organization
        )

        if use_temp_dir is None:
            use_temp_dir = not os.path.isdir(working_dir)

        with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
            files_timestamps = self._get_files_timestamps(work_dir)

            # Save all files.
            # best effort only: any failure of `save_pretrained` is deliberately ignored,
            # the tokenizer file written by `save` just below is what gets uploaded
            with contextlib.suppress(Exception):
                self.save_pretrained(
                    work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization
                )

            self.save(os.path.join(work_dir, self.vocab_files_names))

            return self._upload_modified_files(
                work_dir,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=token,
                create_pr=create_pr,
            )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        return_fast_tokenizer: Optional[bool] = False,
        proxies: Optional[Dict[str, str]] = None,
        **kwargs,
    ):
        r"""
        Instantiate a SAFE tokenizer from a tokenizer file that is local, remote, or hosted on the hub.

        Args:
            pretrained_model_name_or_path:
                Can be either:

                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                  user or organization name, like `dbmdz/bert-base-german-cased`.
                - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
                  using the [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`] method, e.g.,
                  `./my_model_directory/`.
                - A path or url to a single saved tokenizer file, e.g., `./my_model_directory/tokenizer.json`.
            cache_dir: Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
                standard cache should not be used.
            force_download: Whether or not to force the (re-)download the vocabulary files and override the cached versions if they exist.
            proxies: A dictionary of proxy servers to use by protocol or endpoint, e.g.,
                `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            token: The token to use as HTTP bearer authorization for remote files.
                If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            local_files_only: Whether or not to only rely on local files and not to attempt to download any files.
            return_fast_tokenizer: Whether to return fast tokenizer or not.

        Returns:
            the loaded tokenizer, or the result of its `get_pretrained()` when `return_fast_tokenizer` is set

        Raises:
            ValueError: if `token` and the deprecated `use_auth_token` are both given.
            FileNotFoundError: if no tokenizer file could be resolved.

        Examples:
        ``` py
            # Download the tokenizer file from huggingface.co and cache it.
            tokenizer = SAFETokenizer.from_pretrained("datamol-io/safe-gpt")

            # If the tokenizer file is in a directory
            tokenizer = SAFETokenizer.from_pretrained("./test/saved_model/")

            # You can also point directly to the tokenizer file
            tokenizer = SAFETokenizer.from_pretrained("./test/saved_model/tokenizer.json")
        ```
        """
        resume_download = kwargs.pop("resume_download", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        subfolder = kwargs.pop("subfolder", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        commit_hash = kwargs.pop("_commit_hash", None)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        user_agent = {
            "file_type": "tokenizer",
            "from_auto_class": from_auto_class,
            "is_fast": "Fast" in cls.__name__,
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        file_path = None
        if os.path.isfile(pretrained_model_name_or_path):
            # direct path to a local tokenizer file
            file_path = pretrained_model_name_or_path
        elif is_remote_url(pretrained_model_name_or_path):
            file_path = download_url(pretrained_model_name_or_path, proxies=proxies)

        else:
            # EN: remove this when transformers package has uniform API
            cached_file_extra_kwargs = {"use_auth_token": token}
            if packaging.version.parse(transformers_version) >= packaging.version.parse("5.0"):
                cached_file_extra_kwargs = {"token": token}
            # Try to get the tokenizer config to see if there are versioned tokenizer files.
            resolved_vocab_files = cached_file(
                pretrained_model_name_or_path,
                cls.vocab_files_names,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                subfolder=subfolder,
                user_agent=user_agent,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
                _commit_hash=commit_hash,
                **cached_file_extra_kwargs,
            )
            commit_hash = extract_commit_hash(resolved_vocab_files, commit_hash)
            file_path = resolved_vocab_files

        # Fail with an explicit error instead of logging and then crashing further down:
        # `file_path` may be None here, in which case `os.path.isfile` itself raises a
        # TypeError, and a non-existing path cannot be loaded anyway.
        if file_path is None or not os.path.isfile(file_path):
            raise FileNotFoundError(
                f"Can't load the following file: {file_path} required for loading the tokenizer"
            )

        tokenizer = cls.load(file_path)
        if return_fast_tokenizer:
            return tokenizer.get_pretrained()
        return tokenizer

bos_token_id property

Get the bos token id

eos_token_id property

Get the eos token id

pad_token_id property

Get the pad token id

__getstate__()

Getting state to allow pickling

Source code in safe/tokenizer.py
195
196
197
198
199
200
201
202
def __getstate__(self):
    """Getting state to allow pickling

    Returns:
        a deep copy of the instance `__dict__` whose "tokenizer" entry carries a
        `Whitespace()` pre-tokenizer, plus a "tokenizer_attrs" entry holding a
        shallow copy of the tokenizer's own `__dict__`.
    """
    # the deep copy is taken with `Whitespace()` set as pre-tokenizer through `attr_as`
    # (presumably a temporary swap because a custom pre-tokenizer cannot be copied -- TODO confirm)
    with attr_as(self.tokenizer, "pre_tokenizer", Whitespace()):
        d = copy.deepcopy(self.__dict__)
    # copy back tokenizer level attribute
    d["tokenizer_attrs"] = self.tokenizer.__dict__.copy()
    d["tokenizer"].pre_tokenizer = Whitespace()
    # NOTE(review): no "custom_pre_tokenizer" key is stored in this state, while
    # `__setstate__` reads that key to decide whether to restore `SAFESplitter` -- confirm intent.
    return d

__len__()

Gets the count of tokens in vocab along with special tokens.

Source code in safe/tokenizer.py
226
227
228
229
230
def __len__(self):
    r"""
    Gets the count of tokens in vocab along with special tokens.

    Computed as the number of keys returned by `self.tokenizer.get_vocab()`.
    """
    return len(self.tokenizer.get_vocab().keys())

__setstate__(d)

Setting state during reloading pickling

Source code in safe/tokenizer.py
204
205
206
207
208
209
210
def __setstate__(self, d):
    """Setting state during reloading pickling

    Args:
        d: state dictionary; its "tokenizer" entry is updated in place, then the
            whole dictionary is merged into the instance `__dict__`.
    """
    # NOTE(review): the `__getstate__` shown for this class never writes
    # "custom_pre_tokenizer", so this lookup yields None for states it produced and
    # the custom pre-tokenizer branch below is skipped -- confirm intent.
    use_pretokenizer = d.get("custom_pre_tokenizer")
    if use_pretokenizer:
        d["tokenizer"].pre_tokenizer = PreTokenizer.custom(SAFESplitter())
    # restore the attributes saved under "tokenizer_attrs" onto the tokenizer object
    d["tokenizer"].__dict__.update(d.get("tokenizer_attrs", {}))
    self.__dict__.update(d)

decode(ids, skip_special_tokens=True, ignore_stops=False, stop_token_ids=None)

Decodes a list of ids to molecular representation in the format in which this tokenizer was created.

Parameters:

Name Type Description Default
ids list

list of IDs

required
skip_special_tokens bool

whether to skip all special tokens when encountering them

True
ignore_stops bool

whether to ignore the stop tokens, thus decoding till the end

False
stop_token_ids Optional[List[int]]

optional list of stop token ids to use

None

Returns:

Name Type Description
sequence str

str representation of molecule

Source code in safe/tokenizer.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
def decode(
    self,
    ids: list,
    skip_special_tokens: bool = True,
    ignore_stops: bool = False,
    stop_token_ids: Optional[List[int]] = None,
) -> str:
    r"""
    Decodes a list of ids to molecular representation in the format in which this tokenizer was created.

    Args:
        ids: list of IDs, or a list of such lists (lists, numpy arrays or tensors)
        skip_special_tokens: whether to skip all special tokens when encountering them
        ignore_stops: whether to ignore the stop tokens, thus decoding till the end
        stop_token_ids: optional list of stop token ids to use; defaults to the eos token id

    Returns:
        sequence: str representation of molecule when a single sequence is decoded,
            otherwise a list of str (one per input sequence)
    """
    # an empty input or a flat sequence of ids is handled as a single sample
    is_batch = len(ids) > 0 and (
        isinstance(ids[0], (list, np.ndarray)) or torch.is_tensor(ids[0])
    )
    batch = ids if is_batch else [ids]
    if not stop_token_ids:
        stop_token_ids = [self.tokenizer.token_to_id(self.tokenizer.eos_token)]

    kept_batch = []
    for seq in batch:
        kept = seq
        if not ignore_stops:
            kept = []
            # if first tokens are stop, we just remove it
            # this is because of bart essentially
            start = 0
            # we only ignore when there is a list of tokens
            if len(seq) > 1:
                # bounded scan: a sequence made only of stop tokens must not overrun
                while start < len(seq) and seq[start] in stop_token_ids:
                    start += 1
            for token_id in seq[start:]:
                if int(token_id) in stop_token_ids:
                    break
                kept.append(token_id)
        kept_batch.append(kept)
    if len(kept_batch) == 1:
        return self.tokenizer.decode(
            list(kept_batch[0]), skip_special_tokens=skip_special_tokens
        )
    return self.tokenizer.decode_batch(
        list(kept_batch), skip_special_tokens=skip_special_tokens
    )

encode(sample_str, ids_only=True, **kwargs)

Encodes a given molecule string once training is done

Parameters:

Name Type Description Default
sample_str str

Sample string to encode molecule

required
ids_only bool

whether to return only the ids or the encoding object

True

Returns:

Name Type Description
object list

Returns encoded list of IDs

Source code in safe/tokenizer.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
def encode(self, sample_str: str, ids_only: bool = True, **kwargs) -> list:
    r"""
    Encode one molecule string, or a batch of them, with the trained tokenizer

    Args:
        sample_str: a single string, or any non-string input which is then encoded as a batch
        ids_only: whether to return only the ids or the encoding object
        **kwargs: forwarded to the tokenizer `encode` / `encode_batch` call

    Returns:
        object: ids (or encoding) for a single string, list of ids (or encodings) for a batch
    """
    if not isinstance(sample_str, str):
        batch = self.tokenizer.encode_batch(sample_str, **kwargs)
        return [item.ids for item in batch] if ids_only else batch

    encoding = self.tokenizer.encode(sample_str, **kwargs)
    return encoding.ids if ids_only else encoding

from_dict(data) classmethod

Load tokenizer from dict

Parameters:

Name Type Description Default
data dict

dictionary containing the tokenizer info

required
Source code in safe/tokenizer.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
@classmethod
def from_dict(cls, data: dict):
    """Load tokenizer from dict

    Args:
        data: dictionary containing the tokenizer info. It is left unmodified.

    Returns:
        a new instance of this class wrapping the deserialized tokenizer
    """
    # Work on a shallow copy: the SAFE-specific keys are popped before the remainder is
    # handed to `Tokenizer.from_str`, and popping them from the caller's dict would make
    # a second `from_dict` call on the same dict silently lose them.
    data = dict(data)
    tokenizer_type = data.pop("tokenizer_type", "safe")
    tokenizer_attrs = data.pop("tokenizer_attrs", None)
    custom_pre_tokenizer = data.pop("custom_pre_tokenizer", False)
    tokenizer = Tokenizer.from_str(json.dumps(data))
    if custom_pre_tokenizer:
        tokenizer.pre_tokenizer = PreTokenizer.custom(SAFESplitter())
    mol_tokenizer = cls(tokenizer_type)
    mol_tokenizer.tokenizer = mol_tokenizer.set_special_tokens(tokenizer)
    # restore the tokenizer-level attributes that were saved alongside the JSON dump
    if tokenizer_attrs and isinstance(tokenizer_attrs, dict):
        mol_tokenizer.tokenizer.__dict__.update(tokenizer_attrs)
    return mol_tokenizer

from_pretrained(pretrained_model_name_or_path, cache_dir=None, force_download=False, local_files_only=False, token=None, return_fast_tokenizer=False, proxies=None, **kwargs) classmethod

Instantiate a [~tokenization_utils_base.PreTrainedTokenizerBase] (or a derived class) from a predefined tokenizer.

Parameters:

Name Type Description Default
pretrained_model_name_or_path Union[str, PathLike]

Can be either:

  • A string, the model id of a predefined tokenizer hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  • A path to a directory containing vocabulary files required by the tokenizer, for instance saved using the [~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained] method, e.g., ./my_model_directory/.
  • (Deprecated, not applicable to all derived classes) A path or url to a single saved vocabulary file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g., ./my_model_directory/vocab.txt.
required
cache_dir Optional[Union[str, PathLike]]

Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

None
force_download bool

Whether or not to force the (re-)download of the vocabulary files and override the cached versions if they exist.

False
proxies Optional[Dict[str, str]]

A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

None
token Optional[Union[str, bool]]

The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

None
local_files_only bool

Whether or not to only rely on local files and not to attempt to download any files.

False
return_fast_tokenizer Optional[bool]

Whether to return fast tokenizer or not.

False

Examples:

    # We can't instantiate directly the base class *PreTrainedTokenizerBase* so let's show our examples on a derived class: BertTokenizer
    # Download vocabulary from huggingface.co and cache.
    tokenizer = SAFETokenizer.from_pretrained("datamol-io/safe-gpt")

    # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
    tokenizer = SAFETokenizer.from_pretrained("./test/saved_model/")

    # If the tokenizer uses a single vocabulary file, you can point directly to this file
    tokenizer = SAFETokenizer.from_pretrained("./test/saved_model/tokenizer.json")
Source code in safe/tokenizer.py
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[Union[str, bool]] = None,
    return_fast_tokenizer: Optional[bool] = False,
    proxies: Optional[Dict[str, str]] = None,
    **kwargs,
):
    r"""
    Instantiate a tokenizer from a local file, a remote url or a model repo on huggingface.co.

    Args:
        pretrained_model_name_or_path:
            Can be either:

            - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
              Valid model ids can be located at the root-level, or namespaced under a
              user or organization name, like `datamol-io/safe-gpt`.
            - A path to a *directory* containing the tokenizer file, e.g., `./my_model_directory/`.
            - A path or url to a single saved tokenizer file, e.g., `./my_model_directory/tokenizer.json`.
        cache_dir: Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
            standard cache should not be used.
        force_download: Whether or not to force the (re-)download of the vocabulary files and override the cached versions if they exist.
        proxies: A dictionary of proxy servers to use by protocol or endpoint, e.g.,
            `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token: The token to use as HTTP bearer authorization for remote files.
            If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
        local_files_only: Whether or not to only rely on local files and not to attempt to download any files.
        return_fast_tokenizer: Whether to return fast tokenizer or not.
        **kwargs: only `resume_download`, `use_auth_token` (deprecated), `subfolder`, `_from_pipeline`,
            `_from_auto` and `_commit_hash` are read; anything else is ignored.

    Returns:
        The loaded tokenizer, or the result of its `get_pretrained()` when `return_fast_tokenizer` is True.

    Raises:
        ValueError: if both `token` and the deprecated `use_auth_token` are given.
        OSError: if no tokenizer file could be resolved for `pretrained_model_name_or_path`.

    Examples:
    ``` py
        # Download vocabulary from huggingface.co and cache.
        tokenizer = SAFETokenizer.from_pretrained("datamol-io/safe-gpt")

        # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
        tokenizer = SAFETokenizer.from_pretrained("./test/saved_model/")

        # You can also point directly to the saved tokenizer file
        tokenizer = SAFETokenizer.from_pretrained("./test/saved_model/tokenizer.json")
    ```
    """
    resume_download = kwargs.pop("resume_download", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    subfolder = kwargs.pop("subfolder", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)
    commit_hash = kwargs.pop("_commit_hash", None)

    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
            )
        token = use_auth_token

    user_agent = {
        "file_type": "tokenizer",
        "from_auto_class": from_auto_class,
        "is_fast": "Fast" in cls.__name__,
    }
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    file_path = None
    if os.path.isfile(pretrained_model_name_or_path):
        # direct path to a tokenizer file
        file_path = pretrained_model_name_or_path
    elif is_remote_url(pretrained_model_name_or_path):
        file_path = download_url(pretrained_model_name_or_path, proxies=proxies)

    else:
        # EN: remove this when transformers package has uniform API
        cached_file_extra_kwargs = {"use_auth_token": token}
        if packaging.version.parse(transformers_version) >= packaging.version.parse("5.0"):
            cached_file_extra_kwargs = {"token": token}
        # Try to get the tokenizer config to see if there are versioned tokenizer files.
        resolved_vocab_files = cached_file(
            pretrained_model_name_or_path,
            cls.vocab_files_names,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            subfolder=subfolder,
            user_agent=user_agent,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
            _commit_hash=commit_hash,
            **cached_file_extra_kwargs,
        )
        commit_hash = extract_commit_hash(resolved_vocab_files, commit_hash)
        file_path = resolved_vocab_files

    if file_path is None:
        # Nothing was resolved (exceptions for missing entries are disabled in the
        # `cached_file` call above). `os.path.isfile(None)` would raise an opaque
        # TypeError, so fail with an explicit message instead.
        raise OSError(
            f"Can't load tokenizer for '{pretrained_model_name_or_path}': "
            f"the file '{cls.vocab_files_names}' could not be resolved."
        )

    if not os.path.isfile(file_path):
        logger.info(
            f"Can't load the following file: {file_path} required for loading the tokenizer"
        )

    tokenizer = cls.load(file_path)
    if return_fast_tokenizer:
        return tokenizer.get_pretrained()
    return tokenizer

get_pretrained(**kwargs)

Get a pretrained tokenizer from this tokenizer

Returns:

Type Description
PreTrainedTokenizerFast

Returns pre-trained fast tokenizer for hugging face models.

Source code in safe/tokenizer.py
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
def get_pretrained(self, **kwargs) -> PreTrainedTokenizerFast:
    r"""
    Get a pretrained tokenizer from this tokenizer

    Args:
        **kwargs: accepted for interface compatibility; not used

    Returns:
        Returns pre-trained fast tokenizer for hugging face models.
    """
    # build the fast tokenizer while the pre-tokenizer is temporarily a plain
    # Whitespace() one, then put the actual pre-tokenizer back on the wrapped object
    with attr_as(self.tokenizer, "pre_tokenizer", Whitespace()):
        tk = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
    tk._tokenizer.pre_tokenizer = self.tokenizer.pre_tokenizer
    # now we need to add special_tokens
    tk.add_special_tokens(
        {
            "cls_token": self.tokenizer.cls_token,
            "bos_token": self.tokenizer.bos_token,
            "eos_token": self.tokenizer.eos_token,
            "mask_token": self.tokenizer.mask_token,
            "pad_token": self.tokenizer.pad_token,
            "unk_token": self.tokenizer.unk_token,
            "sep_token": self.tokenizer.sep_token,
        }
    )
    # Only override the max length when it is unset (or the huge placeholder value)
    # AND our tokenizer actually carries one. The explicit grouping matters: without
    # it `and` binds tighter than `or`, and an unset length would read a missing
    # `model_max_length` attribute.
    max_length_unset = tk.model_max_length is None or tk.model_max_length > 1e8
    if max_length_unset and hasattr(self.tokenizer, "model_max_length"):
        tk.model_max_length = self.tokenizer.model_max_length
    return tk

load(file_name) classmethod

Load the current tokenizer from file

Source code in safe/tokenizer.py
304
305
306
307
308
309
310
311
312
@classmethod
def load(cls, file_name):
    """Rebuild a tokenizer from the JSON file at `file_name`"""
    with fsspec.open(file_name, "r") as handle:
        payload = handle.read()
    # EN: the rust json parser of tokenizers has a predefined structure, so the
    # content is parsed here first and `from_dict` deals with the extra keys
    return cls.from_dict(json.loads(payload))

push_to_hub(repo_id, use_temp_dir=None, commit_message=None, private=None, token=None, max_shard_size='10GB', create_pr=False, safe_serialization=False, **deprecated_kwargs)

Upload the tokenizer to the 🤗 Model Hub.

Parameters:

Name Type Description Default
repo_id str

The name of the repository you want to push your {object} to. It should contain your organization name when pushing to a given organization.

required
use_temp_dir Optional[bool]

Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub. Will default to True if there is no directory named like repo_id, False otherwise.

None
commit_message Optional[str]

Message to commit while pushing. Will default to "Upload {object}".

None
private Optional[bool]

Whether or not the repository created should be private.

None
token Optional[Union[bool, str]]

The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface). Will default to True if repo_url is not specified.

None
max_shard_size Optional[Union[int, str]]

Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size lower than this size. If expressed as a string, needs to be digits followed by a unit (like "5MB").

'10GB'
create_pr bool

Whether or not to create a PR with the uploaded files or directly commit.

False
safe_serialization bool

Whether or not to convert the model weights in safetensors format for safer serialization.

False
Source code in safe/tokenizer.py
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
def push_to_hub(
    self,
    repo_id: str,
    use_temp_dir: Optional[bool] = None,
    commit_message: Optional[str] = None,
    private: Optional[bool] = None,
    token: Optional[Union[bool, str]] = None,
    max_shard_size: Optional[Union[int, str]] = "10GB",
    create_pr: bool = False,
    safe_serialization: bool = False,
    **deprecated_kwargs,
) -> str:
    """
    Upload the tokenizer to the 🤗 Model Hub.

    Args:
        repo_id: The name of the repository you want to push your tokenizer to. It should contain your organization name
            when pushing to a given organization.
        use_temp_dir: Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
            Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
        commit_message: Message to commit while pushing.
        private: Whether or not the repository created should be private.
        token: The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
            is not specified.
        max_shard_size: Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
            will then be each of size lower than this size. If expressed as a string, needs to be digits followed
            by a unit (like `"5MB"`).
        create_pr: Whether or not to create a PR with the uploaded files or directly commit.
        safe_serialization: Whether or not to convert the model weights in safetensors format for safer serialization.
        **deprecated_kwargs: the deprecated `use_auth_token`, `repo_path_or_name`, `repo_url` and
            `organization` arguments are read from here.

    Returns:
        whatever `_upload_modified_files` returns for the upload.

    Raises:
        ValueError: if a deprecated argument is given together with its replacement
            (`use_auth_token` with `token`, or `repo_path_or_name` with `repo_id`).
    """
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
            )
        token = use_auth_token

    repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
    if repo_path_or_name is not None:
        # Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
        # repo_id from the folder path, if it exists.
        warnings.warn(
            "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
            "`repo_id` instead.",
            FutureWarning,
        )
        if repo_id is not None:
            raise ValueError(
                "`repo_id` and `repo_path_or_name` are both specified. Please set only the argument `repo_id`."
            )
        if os.path.isdir(repo_path_or_name):
            # repo_path: infer repo_id from the path. `repo_id` is necessarily None
            # here, so the name has to come from `repo_path_or_name`; normpath drops
            # a trailing separator that would otherwise give an empty name.
            working_dir = repo_path_or_name
            repo_id = os.path.basename(os.path.normpath(repo_path_or_name))
        else:
            # repo_name: use it as repo_id
            repo_id = repo_path_or_name
            working_dir = repo_id.split("/")[-1]
    else:
        # Repo_id is passed correctly: infer working_dir from it
        working_dir = repo_id.split("/")[-1]

    # Deprecation warning will be sent after for repo_url and organization
    repo_url = deprecated_kwargs.pop("repo_url", None)
    organization = deprecated_kwargs.pop("organization", None)

    repo_id = self._create_repo(
        repo_id, private, token, repo_url=repo_url, organization=organization
    )

    if use_temp_dir is None:
        use_temp_dir = not os.path.isdir(working_dir)

    with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
        files_timestamps = self._get_files_timestamps(work_dir)

        # Save all files. This step is best effort: any failure is deliberately
        # ignored because the tokenizer file itself is written by `self.save` below.
        with contextlib.suppress(Exception):
            self.save_pretrained(
                work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization
            )

        self.save(os.path.join(work_dir, self.vocab_files_names))

        return self._upload_modified_files(
            work_dir,
            repo_id,
            files_timestamps,
            commit_message=commit_message,
            token=token,
            create_pr=create_pr,
        )

save(file_name=None)

Saves the :class:~tokenizers.Tokenizer to the file at the given path.

Parameters:

Name Type Description Default
file_name str

File where to save tokenizer

None
Source code in safe/tokenizer.py
272
273
274
275
276
277
278
279
280
281
282
283
def save(self, file_name=None):
    r"""
    Saves the :class:`~tokenizers.Tokenizer` to the file at the given path.

    The content written is the JSON form of `self.to_dict()`, encoded as UTF-8.

    Args:
        file_name (str, optional): File where to save tokenizer
    """
    # EN: whole logic here assumes noone is going to mess with the special token
    tk_data = self.to_dict()
    # Serialise before opening the target: if the data cannot be dumped to JSON,
    # the error is raised without having opened the destination for writing.
    out_str = json.dumps(tk_data, ensure_ascii=False)
    with fsspec.open(file_name, "w", encoding="utf-8") as OUT:
        OUT.write(out_str)

save_pretrained(*args, **kwargs)

Save pretrained tokenizer

Source code in safe/tokenizer.py
268
269
270
def save_pretrained(self, *args, **kwargs):
    """Delegate to the wrapped tokenizer's `save_pretrained`, passing every argument through.

    Nothing is returned, whatever the wrapped call produces.
    """
    delegate = self.tokenizer.save_pretrained
    delegate(*args, **kwargs)

set_special_tokens(tokenizer, bos_token=CLS_TOKEN, eos_token=SEP_TOKEN) classmethod

Set special tokens for a tokenizer

Parameters:

Name Type Description Default
tokenizer Tokenizer

tokenizer for which special tokens will be set

required
bos_token str

Optional bos token to use

CLS_TOKEN
eos_token str

Optional eos token to use

SEP_TOKEN
Source code in safe/tokenizer.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
@classmethod
def set_special_tokens(
    cls,
    tokenizer: Tokenizer,
    bos_token: str = CLS_TOKEN,
    eos_token: str = SEP_TOKEN,
):
    """Attach the special tokens to a tokenizer and hand it back

    Each special token is stored as an attribute on `tokenizer`. When the
    object is a `Tokenizer`, the same tokens are also registered through its
    `add_special_tokens` method.

    Args:
        tokenizer: tokenizer for which special tokens will be set
        bos_token: Optional bos token to use
        eos_token: Optional eos token to use

    Returns:
        the same `tokenizer` object that was passed in
    """
    # (attribute name, token) pairs, in the order in which they are applied
    special_attrs = (
        ("pad_token", PADDING_TOKEN),
        ("cls_token", CLS_TOKEN),
        ("sep_token", SEP_TOKEN),
        ("mask_token", MASK_TOKEN),
        ("unk_token", UNK_TOKEN),
        ("eos_token", eos_token),
        ("bos_token", bos_token),
    )
    for attr_name, token_value in special_attrs:
        setattr(tokenizer, attr_name, token_value)

    if not isinstance(tokenizer, Tokenizer):
        return tokenizer

    tokenizer.add_special_tokens([token_value for _, token_value in special_attrs])
    return tokenizer

to_dict(**kwargs)

Convert tokenizer to dict

Source code in safe/tokenizer.py
254
255
256
257
258
259
260
261
262
263
264
265
266
def to_dict(self, **kwargs):
    """Convert tokenizer to dict

    The result is the parsed JSON of the wrapped tokenizer plus the keys
    `tokenizer_type`, `tokenizer_attrs` and, when a splitter is set,
    `custom_pre_tokenizer`.

    Args:
        **kwargs: accepted for interface compatibility; not used

    Returns:
        dictionary describing the tokenizer
    """
    # we need to do this because HuggingFace tokenizers doesnt save with custom pre-tokenizers
    if self.splitter is None:
        tk_data = json.loads(self.tokenizer.to_str())
    else:
        with attr_as(self.tokenizer, "pre_tokenizer", Whitespace()):
            # temporary replace pre tokenizer with whitespace
            tk_data = json.loads(self.tokenizer.to_str())
            tk_data["custom_pre_tokenizer"] = True
    tk_data["tokenizer_type"] = self.tokenizer_type
    # copy the attribute dict: returning the live `__dict__` would let any edit
    # of the result silently change the tokenizer's own attributes
    tk_data["tokenizer_attrs"] = dict(self.tokenizer.__dict__)
    return tk_data

train(files, **kwargs)

This is to train a new tokenizer from either a single file or a list of files

Args: files (str): file in which your molecules are separated by new lines; kwargs (dict): optional args for the tokenizer train

Source code in safe/tokenizer.py
183
184
185
186
187
188
189
190
191
192
193
def train(self, files: Optional[List[str]], **kwargs):
    r"""
    Train the tokenizer on one file or on a list of files, using `self.trainer`.

    Args:
        files: path of a file in which your molecules are separated by new line,
            or a list of such paths; a single string is wrapped into a list
        kwargs: accepted, but not forwarded to the tokenizer `train` call
    """
    file_list = [files] if isinstance(files, str) else files
    self.tokenizer.train(files=file_list, trainer=self.trainer)

train_from_iterator(data, **kwargs)

Train the Tokenizer using the provided iterator.

You can provide anything that is a Python Iterator: a list of sequences (:obj:List[str]), a generator that yields :obj:str or :obj:List[str], or a Numpy array of strings.

Parameters:

Name Type Description Default
data Iterator

data iterator

required
**kwargs Any

additional keyword argument for the tokenizer train_from_iterator

{}
Source code in safe/tokenizer.py
212
213
214
215
216
217
218
219
220
221
222
223
224
def train_from_iterator(self, data: Iterator, **kwargs: Any):
    """Fit the tokenizer on samples drawn from an iterator, using `self.trainer`.

    `data` can be any Python iterator, for instance
        * A list of sequences :obj:`List[str]`
        * A generator that yields :obj:`str` or :obj:`List[str]`
        * A Numpy array of strings

    Args:
        data: data iterator
        **kwargs: extra keyword arguments passed on to the tokenizer `train_from_iterator`
    """
    fit = self.tokenizer.train_from_iterator
    fit(data, trainer=self.trainer, **kwargs)

Utils

MolSlicer

Slice a molecule into head-linker-tail

Source code in safe/utils.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
class MolSlicer:
    """Slice a molecule into head-linker-tail

    Candidate bonds are located with the SMARTS patterns in `BOND_SPLITTERS`.
    Calling an instance returns a `(head, linker, tail)` tuple, and
    `link_fragments` joins three such fragments back together through
    their attachment points.
    """

    # SMARTS patterns describing the bonds that may be cut
    BOND_SPLITTERS = [
        # two atoms connected by a non-ring single bond, one of which is not in a ring and has at least two neighbors
        "[R:1]-&!@[!R;!D1:2]",
        # two atoms in different rings linked by a non-ring single bond
        "[R:1]-&!@[R:2]",
    ]
    _BOND_BUFFER = 1  # buffer around substructure match size.
    MAX_CUTS = 2  # maximum number of cuts. Here we need two cuts for head-linker-tail.

    # Reaction used by `link_fragments`; reactants are given as (head, tail, linker).
    # Each dummy atom ([#0]) is dropped and its neighbor is bonded to the matching
    # neighbor on the linker (1 with 2, 3 with 4).
    _MERGING_RXN = dm.reactions.rxn_from_smarts(
        "[#0][*:1].[#0][*:4].([#0][*:2].[#0][*:3])>>([*:1][*:2].[*:3][*:4])"
    )

    def __init__(
        self,
        shortest_linker: bool = False,
        min_linker_size: int = 0,
        require_ring_system: bool = True,
        verbose: bool = False,
    ):
        """
        Constructor of bond slicer.

        Args:
            shortest_linker: whether to consider longest or shortest linker.
                Does not have any effect when expected_head group is provided during splitting
            min_linker_size: minimum linker size
            require_ring_system: whether all fragment needs to have a ring system
            verbose: whether to allow verbosity in logging
        """

        # compile the SMARTS patterns once per instance
        self.bond_splitters = [dm.from_smarts(x) for x in self.BOND_SPLITTERS]
        self.shortest_linker = shortest_linker
        self.min_linker_size = min_linker_size
        self.require_ring_system = require_ring_system
        self.verbose = verbose

    def get_ring_system(self, mol: dm.Mol):
        """Get the list of ring system from a molecule

        Rings that share at least one atom are merged into the same system.

        Args:
            mol: input molecule for which we are computing the ring system

        Returns:
            list of sets of atom indices, one set per ring system
        """
        mol.UpdatePropertyCache()
        ri = mol.GetRingInfo()
        systems = []
        for ring in ri.AtomRings():
            ring_atoms = set(ring)
            cur_system = []  # keep a track of ring system
            for system in systems:
                if len(ring_atoms.intersection(system)) > 0:
                    ring_atoms = ring_atoms.union(system)  # merge ring system that overlap
                else:
                    cur_system.append(system)
            cur_system.append(ring_atoms)
            systems = cur_system
        return systems

    def _bond_selection_from_max_cuts(self, bond_list: List[int], dist_mat: np.ndarray):
        """Select bonds based on maximum number of cuts allowed

        Args:
            bond_list: candidate bonds. Each entry is iterated over below, so the
                entries are presumably (atom, atom) index pairs rather than plain
                ints as the annotation suggests -- TODO confirm
            dist_mat: matrix indexed by two atom indices; `__call__` passes the
                molecule's distance matrix

        Returns:
            pair of indices into `bond_list` for the two bonds to cut

        Raises:
            ValueError: if `MAX_CUTS` is not 2
        """
        # for now we are just implementing to 2 max cuts algorithms
        if self.MAX_CUTS != 2:
            raise ValueError(f"Only MAX_CUTS=2 is supported, got {self.MAX_CUTS}")

        bond_pdist = np.full((len(bond_list), len(bond_list)), -1)
        for i in range(len(bond_list)):
            for j in range(i, len(bond_list)):
                # we get the minimum topological distance between bond to cut
                bond_pdist[i, j] = bond_pdist[j, i] = min(
                    [dist_mat[a1, a2] for a1, a2 in itertools.product(bond_list[i], bond_list[j])]
                )

        # hide every pair of bonds whose distance is not above the minimum linker size
        # NOTE(review): when every pair ends up masked, the index returned by
        # np.ma.argmin / np.ma.argmax below may designate the same bond twice -- confirm
        masked_bond_pdist = np.ma.masked_less_equal(bond_pdist, self.min_linker_size)

        if self.shortest_linker:
            return np.unravel_index(np.ma.argmin(masked_bond_pdist), bond_pdist.shape)
        return np.unravel_index(np.ma.argmax(masked_bond_pdist), bond_pdist.shape)

    def _get_bonds_to_cut(self, mol: dm.Mol):
        """Get possible bond to cuts

        Args:
            mol: input molecule

        Returns:
            list of substructure matches (atom indices) for the bonds that can be cut
        """
        # use this if you want to enumerate yourself the possible cuts

        ring_systems = self.get_ring_system(mol)
        candidate_bonds = []
        ring_query = Chem.rdqueries.IsInRingQueryAtom()

        for query in self.bond_splitters:
            bonds = mol.GetSubstructMatches(query, uniquify=True)
            # snapshot of the bonds accepted by the previous queries
            cur_unique_bonds = [set(cbond) for cbond in candidate_bonds]
            # do not accept bonds part of the same ring system or already known
            for b in bonds:
                bond_id = mol.GetBondBetweenAtoms(*b).GetIdx()
                # cut this single bond to look at the fragments it would produce
                bond_cut = Chem.GetMolFrags(
                    Chem.FragmentOnBonds(mol, [bond_id], addDummies=False), asMols=True
                )
                # when ring systems are required, every fragment must contain a ring atom
                can_add = not self.require_ring_system or all(
                    len(frag.GetAtomsMatchingQuery(ring_query)) > 0 for frag in bond_cut
                )
                if can_add and not (
                    set(b) in cur_unique_bonds or any(x.issuperset(set(b)) for x in ring_systems)
                ):
                    candidate_bonds.append(b)
        return candidate_bonds

    def _fragment_mol(self, mol: dm.Mol, bonds: List[dm.Bond]):
        """Fragment molecules on bonds and return head, linker, tail combination

        Args:
            mol: input molecule
            bonds: list of bonds to cut

        Returns:
            tuple `(head, linker, tail)` of fragments

        Raises:
            ValueError: from the tuple unpacking, if removing the linker does not
                leave exactly two fragments
        """
        tmp = Chem.rdmolops.FragmentOnBonds(mol, [b.GetIdx() for b in bonds])
        _frags = list(Chem.GetMolFrags(tmp, asMols=True))
        # linker is the one with 2 dummy atoms
        # (the first fragment is used when no fragment has exactly two)
        linker_pos = 0
        for pos, _frag in enumerate(_frags):
            if sum([at.GetSymbol() == "*" for at in _frag.GetAtoms()]) == 2:
                linker_pos = pos
                break
        linker = _frags.pop(linker_pos)
        head, tail = _frags
        return (head, linker, tail)

    def _compute_linker_score(self, linker: dm.Mol):
        """Compute the score of a linker to help select between linkers

        `__call__` prefers the candidate with the lower score.

        Args:
            linker: linker fragment with at least two dummy ("*") atoms

        Returns:
            `inf` when the score is zero (which includes a linker without ring atom
            while `require_ring_system` is set); otherwise the length of the shortest
            path between the first two dummy atoms, inverted unless `shortest_linker`
            is set
        """

        # we need to take into account
        # case where we require the linker to have a ring system
        # case where we want the linker to be longest or shortest

        # find shortest path
        attach1, attach2, *_ = [at.GetIdx() for at in linker.GetAtoms() if at.GetSymbol() == "*"]
        score = len(Chem.rdmolops.GetShortestPath(linker, attach1, attach2))
        ring_query = Chem.rdqueries.IsInRingQueryAtom()
        linker_ring_count = len(linker.GetAtomsMatchingQuery(ring_query))
        if self.require_ring_system:
            # zero out the score of linkers that have no ring atom
            score *= int(linker_ring_count > 0)
        if score == 0:
            return float("inf")
        if not self.shortest_linker:
            # invert so that a longer path gives a lower score
            score = 1 / score
        return score

    def __call__(self, mol: Union[dm.Mol, str], expected_head: Union[dm.Mol, str] = None):
        """Perform slicing of the input molecule

        Args:
            mol: input molecule
            expected_head: substructure that should be part of the head.
                The small fragment containing this substructure would be kept as head

        Returns:
            tuple `(head, linker, tail)`. Without any bond to cut it is `(mol, None, None)`;
            with a single bond to cut the linker is None. When `expected_head` is used and
            no pair of cuts yields a fragment matching it, all three entries are None.
        """

        mol = dm.to_mol(mol)
        # remove salt and solution
        mol = dm.keep_largest_fragment(mol)
        Chem.rdDepictor.Compute2DCoords(mol)
        dist_mat = Chem.rdmolops.GetDistanceMatrix(mol)

        if expected_head is not None:
            if isinstance(expected_head, str):
                expected_head = dm.to_mol(expected_head)
            if not mol.HasSubstructMatch(expected_head):
                if self.verbose:
                    logger.info(
                        "Expected head was provided, but does not match molecules. It will be ignored"
                    )
                expected_head = None

        candidate_bonds = self._get_bonds_to_cut(mol)

        # we have all the candidate bonds we can cut
        # now we need to pick the most plausible bonds
        selected_bonds = [mol.GetBondBetweenAtoms(a1, a2) for (a1, a2) in candidate_bonds]

        # CASE 1: no bond to cut ==> only head
        if len(selected_bonds) == 0:
            return (mol, None, None)

        # CASE 2: only one bond ==> linker is empty
        if len(selected_bonds) == 1:
            # there is no linker
            tmp = Chem.rdmolops.FragmentOnBonds(mol, [b.GetIdx() for b in selected_bonds])
            head, tail = Chem.GetMolFrags(tmp, asMols=True)
            return (head, None, tail)

        # CASE 3a: we select the most plausible bond to cut on ourselves
        if expected_head is None:
            choice = self._bond_selection_from_max_cuts(candidate_bonds, dist_mat)
            selected_bonds = [selected_bonds[c] for c in choice]
            return self._fragment_mol(mol, selected_bonds)

        # CASE 3b: slightly more complex case where we want the head to be the smallest graph containing the
        # provided substructure
        bond_combination = list(itertools.combinations(selected_bonds, self.MAX_CUTS))
        bond_score = float("inf")
        linker_score = float("inf")
        head, linker, tail = (None, None, None)
        for split_bonds in bond_combination:
            cur_head, cur_linker, cur_tail = self._fragment_mol(mol, split_bonds)
            # head can also be tail
            head_match = cur_head.GetSubstructMatch(expected_head)
            tail_match = cur_tail.GetSubstructMatch(expected_head)
            if not head_match and not tail_match:
                continue
            if not head_match and tail_match:
                cur_head, cur_tail = cur_tail, cur_head
            # smaller heads are preferred
            cur_bond_score = cur_head.GetNumHeavyAtoms()
            # compute linker score
            cur_linker_score = self._compute_linker_score(cur_linker)
            # keep a strictly smaller head, or one within the buffer that has a better linker
            if (cur_bond_score < bond_score) or (
                cur_bond_score < self._BOND_BUFFER + bond_score and cur_linker_score < linker_score
            ):
                head, linker, tail = cur_head, cur_linker, cur_tail
                bond_score = cur_bond_score
                linker_score = cur_linker_score

        return (head, linker, tail)

    @classmethod
    def link_fragments(
        cls, linker: Union[dm.Mol, str], head: Union[dm.Mol, str], tail: Union[dm.Mol, str]
    ):
        """Link fragments together using the provided linker

        Args:
            linker: linker to use
            head: head fragment
            tail: tail fragment

        Returns:
            result of applying `_MERGING_RXN` to (head, tail, linker), requested as
            SMILES (`as_smiles=True`, `product_index=0`)
        """
        if isinstance(linker, dm.Mol):
            # `standardize_attach` is fed a SMILES string
            linker = dm.to_smiles(linker)
        linker = standardize_attach(linker)
        reactants = [dm.to_mol(head), dm.to_mol(tail), dm.to_mol(linker)]
        return dm.reactions.apply_reaction(
            cls._MERGING_RXN, reactants, as_smiles=True, sanitize=True, product_index=0
        )

__call__(mol, expected_head=None)

Perform slicing of the input molecule

Parameters:

Name Type Description Default
mol Union[Mol, str]

input molecule

required
expected_head Union[Mol, str]

substructure that should be part of the head. The small fragment containing this substructure would be kept as head

None
Source code in safe/utils.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def __call__(self, mol: Union[dm.Mol, str], expected_head: Union[dm.Mol, str] = None):
    """Slice the input molecule into a ``(head, linker, tail)`` triple.

    Args:
        mol: input molecule
        expected_head: substructure that should be part of the head.
            The smallest fragment containing this substructure is kept as head.
            It is ignored when it does not match the input molecule.

    Returns:
        A ``(head, linker, tail)`` tuple:

        * ``(mol, None, None)`` when there is no bond to cut;
        * ``(head, None, tail)`` when there is exactly one candidate bond;
        * the output of ``_fragment_mol`` on the automatically selected bonds
          when no (matching) ``expected_head`` is given;
        * otherwise the best scoring split whose head contains ``expected_head``,
          or ``(None, None, None)`` if no split contains it.
    """

    mol = dm.to_mol(mol)
    # remove salt and solution
    mol = dm.keep_largest_fragment(mol)
    Chem.rdDepictor.Compute2DCoords(mol)
    dist_mat = Chem.rdmolops.GetDistanceMatrix(mol)

    if expected_head is not None:
        if isinstance(expected_head, str):
            expected_head = dm.to_mol(expected_head)
        if not mol.HasSubstructMatch(expected_head):
            if self.verbose:
                logger.info(
                    "Expected head was provided, but does not match molecules. It will be ignored"
                )
            expected_head = None

    # every candidate (atom, atom) pair is turned into the matching bond object
    cut_pairs = self._get_bonds_to_cut(mol)
    bonds = [mol.GetBondBetweenAtoms(begin, end) for (begin, end) in cut_pairs]
    n_bonds = len(bonds)

    # CASE 1: nothing to cut, the whole molecule is the head
    if n_bonds == 0:
        return (mol, None, None)

    # CASE 2: a single cut gives two pieces and therefore no linker
    if n_bonds == 1:
        fragmented = Chem.rdmolops.FragmentOnBonds(mol, [bond.GetIdx() for bond in bonds])
        head, tail = Chem.GetMolFrags(fragmented, asMols=True)
        return (head, None, tail)

    # CASE 3a: no head constraint, let the slicer pick the bonds itself
    if expected_head is None:
        picked = self._bond_selection_from_max_cuts(cut_pairs, dist_mat)
        return self._fragment_mol(mol, [bonds[idx] for idx in picked])

    # CASE 3b: look for the split whose head is the smallest fragment
    # containing the requested substructure
    best_head, best_linker, best_tail = (None, None, None)
    best_head_size = float("inf")
    best_linker_score = float("inf")
    for bond_subset in itertools.combinations(bonds, self.MAX_CUTS):
        frag_head, frag_linker, frag_tail = self._fragment_mol(mol, bond_subset)
        # the substructure may sit in either terminal fragment
        in_head = frag_head.GetSubstructMatch(expected_head)
        in_tail = frag_tail.GetSubstructMatch(expected_head)
        if not (in_head or in_tail):
            continue
        if not in_head:
            # only the tail matches: it becomes the head
            frag_head, frag_tail = frag_tail, frag_head
        head_size = frag_head.GetNumHeavyAtoms()
        cur_linker_score = self._compute_linker_score(frag_linker)
        # accept a strictly smaller head, or a head within the buffer
        # that comes with a better (lower) linker score
        if (head_size < best_head_size) or (
            head_size < self._BOND_BUFFER + best_head_size
            and cur_linker_score < best_linker_score
        ):
            best_head, best_linker, best_tail = frag_head, frag_linker, frag_tail
            best_head_size = head_size
            best_linker_score = cur_linker_score

    return (best_head, best_linker, best_tail)

__init__(shortest_linker=False, min_linker_size=0, require_ring_system=True, verbose=False)

Constructor of bond slicer.

Parameters:

Name Type Description Default
shortest_linker bool

whether to consider the longest or the shortest linker. Does not have any effect when an expected_head group is provided during splitting

False
min_linker_size int

minimum linker size

0
require_ring_system bool

whether all fragments need to have a ring system

True
verbose bool

whether to allow verbosity in logging

False
Source code in safe/utils.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(
    self,
    shortest_linker: bool = False,
    min_linker_size: int = 0,
    require_ring_system: bool = True,
    verbose: bool = False,
):
    """
    Build a bond slicer.

    Args:
        shortest_linker: whether to consider the longest or the shortest linker.
            Does not have any effect when an expected_head group is provided during splitting
        min_linker_size: minimum linker size
        require_ring_system: whether all fragments need to have a ring system
        verbose: whether to allow verbosity in logging
    """

    # the class-level SMARTS strings are converted to query molecules per instance
    self.bond_splitters = list(map(dm.from_smarts, self.BOND_SPLITTERS))

    # plain configuration flags, stored as given
    self.shortest_linker = shortest_linker
    self.min_linker_size = min_linker_size
    self.require_ring_system = require_ring_system
    self.verbose = verbose

get_ring_system(mol)

Get the list of ring systems from a molecule

Parameters:

Name Type Description Default
mol Mol

input molecule for which we are computing the ring system

required
Source code in safe/utils.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def get_ring_system(self, mol: dm.Mol):
    """Group the rings of a molecule into ring systems.

    Rings that share at least one atom are merged into the same system.

    Args:
        mol: input molecule for which we are computing the ring systems

    Returns:
        A list of sets of atom indices, one set per ring system.
    """
    mol.UpdatePropertyCache()
    merged_systems = []
    for ring in mol.GetRingInfo().AtomRings():
        current = set(ring)
        untouched = []  # systems that share no atom with the current ring
        for known in merged_systems:
            if current.isdisjoint(known):
                untouched.append(known)
            else:
                # overlapping systems are absorbed into the current one
                current = current | known
        untouched.append(current)
        merged_systems = untouched
    return merged_systems

Link fragments together using the provided linker

Parameters:

Name Type Description Default
linker Union[Mol, str]

linker to use

required
head Union[Mol, str]

head fragment

required
tail Union[Mol, str]

tail fragment

required
Source code in safe/utils.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
@classmethod
def link_fragments(
    cls, linker: Union[dm.Mol, str], head: Union[dm.Mol, str], tail: Union[dm.Mol, str]
):
    """Join a head and a tail fragment through a linker.

    A linker given as a molecule is first converted to SMILES; its attachment
    points are then rewritten by ``standardize_attach`` before the class-level
    merging reaction ``_MERGING_RXN`` is applied to ``[head, tail, linker]``.

    Args:
        linker: linker to use, as a molecule or a SMILES string
        head: head fragment
        tail: tail fragment

    Returns:
        The first product of the merging reaction, requested as SMILES
        from ``dm.reactions.apply_reaction`` with sanitization enabled.
    """
    linker_smiles = dm.to_smiles(linker) if isinstance(linker, dm.Mol) else linker
    linker_smiles = standardize_attach(linker_smiles)

    # The reaction expects its reactants in the order head, tail, linker.
    head_mol = dm.to_mol(head)
    tail_mol = dm.to_mol(tail)
    linker_mol = dm.to_mol(linker_smiles)
    return dm.reactions.apply_reaction(
        cls._MERGING_RXN,
        [head_mol, tail_mol, linker_mol],
        as_smiles=True,
        sanitize=True,
        product_index=0,
    )

attr_as(obj, field, value)

Temporarily replace the value of an attribute of an object

Parameters:

Name Type Description Default
obj Any

object to temporarily patch

required
field str

name of the key to change

required
value Any

value of the key to be temporarily changed

required
Source code in safe/utils.py
282
283
284
285
286
287
288
289
290
291
292
293
294
295
@contextmanager
def attr_as(obj: Any, field: str, value: Any):
    """Temporarily replace the value of an attribute of an object.

    The previous value is put back when the ``with`` block exits, including
    when the block exits because of an exception.

    Args:
        obj: object to temporarily patch
        field: name of the attribute to change
        value: value the attribute takes inside the ``with`` block

    Note:
        If ``obj`` had no attribute ``field`` beforehand, the attribute is left
        set to ``None`` afterwards (the fallback of the initial ``getattr``).
    """
    old_value = getattr(obj, field, None)
    setattr(obj, field, value)
    try:
        yield
    finally:
        # Restore in a ``finally`` so an exception raised inside the managed
        # block does not leave the object permanently patched.
        with suppress(TypeError):
            setattr(obj, field, old_value)

compute_side_chains(mol, core, label_by_index=False)

Compute the side chain of a molecule given a core

Finding the side chains

The algorithm to find the side chains from core assumes that the core we get as input has attachment points. Those attachment points are never considered as part of the query, rather they are used to define the attachment points on the side chains. Removing the attachment points from the core is exactly the same as keeping them.

mol = "CC1=C(C(=NO1)C2=CC=CC=C2Cl)C(=O)NC3C4N(C3=O)C(C(S4)(C)C)C(=O)O"
core0 = "CC1(C)CN2C(CC2=O)S1"
core1 = "CC1(C)SC2C(-*)C(=O)N2C1-*"
core2 = "CC1N2C(SC1(C)C)C(N)C2=O"
side_chain = compute_side_chains(core=core0, mol=mol)
dm.to_image([side_chain, core0, mol])
Therefore on the above, core0 and core1 are equivalent for the molecule mol, but core2 is not.

Parameters:

Name Type Description Default
mol Mol

molecule to split

required
core Mol

core to use for deriving the side chains

required
Source code in safe/utils.py
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
def compute_side_chains(mol: dm.Mol, core: dm.Mol, label_by_index: bool = False):
    """Compute the side chains of a molecule given a core

    !!! note "Finding the side chains"
        The algorithm to find the side chains from the core assumes that the core we get as input has attachment points.
        Those attachment points are never considered as part of the query; rather, they are used to define the
        attachment points on the side chains. Removing the attachment points from the core is exactly the same as keeping them.

        ```python
        mol = "CC1=C(C(=NO1)C2=CC=CC=C2Cl)C(=O)NC3C4N(C3=O)C(C(S4)(C)C)C(=O)O"
        core0 = "CC1(C)CN2C(CC2=O)S1"
        core1 = "CC1(C)SC2C(-*)C(=O)N2C1-*"
        core2 = "CC1N2C(SC1(C)C)C(N)C2=O"
        side_chain = compute_side_chains(core=core0, mol=mol)
        dm.to_image([side_chain, core0, mol])
        ```
        Therefore on the above, core0 and core1 are equivalent for the molecule `mol`, but core2 is not.

    Args:
        mol: molecule to split (a SMILES string is converted to a molecule)
        core: core to use for deriving the side chains (a SMILES string is converted to a molecule)
        label_by_index: forwarded to `ReplaceCore` as `labelByIndex`
    """

    target = dm.to_mol(mol) if isinstance(mol, str) else mol
    scaffold = dm.to_mol(core) if isinstance(core, str) else core

    # turn the core into a query: dummies become queries, degree is left untouched
    params = AdjustQueryParameters()
    params.makeDummiesQueries = True
    params.adjustDegree = False
    params.aromatizeIfPossible = True
    params.makeBondsGeneric = False
    scaffold_query = AdjustQueryProperties(scaffold, params)

    return ReplaceCore(
        target,
        scaffold_query,
        labelByIndex=label_by_index,
        replaceDummies=False,
        requireDummyMatch=False,
    )

convert_to_safe(mol, canonical=False, randomize=False, seed=1, slicer='brics', split_fragment=True, fraction_hs=None, resolution=0.5)

Convert a molecule to a safe representation

Parameters:

Name Type Description Default
mol Mol

molecule to convert

required
canonical bool

whether to use canonical encoding

False
randomize bool

whether to randomize the encoding

False
seed Optional[int]

random seed

1
slicer str

the slicer to use for fragmentation

'brics'
split_fragment bool

whether to split fragments

True
fraction_hs bool

proportion of random atom to which we will add explicit hydrogens

None
resolution Optional[float]

resolution for the partitioning algorithm

0.5
seed Optional[int]

random seed

1
Source code in safe/utils.py
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
def convert_to_safe(
    mol: dm.Mol,
    canonical: bool = False,
    randomize: bool = False,
    seed: Optional[int] = 1,
    slicer: str = "brics",
    split_fragment: bool = True,
    fraction_hs: bool = None,
    resolution: Optional[float] = 0.5,
):
    """Convert a molecule to a safe representation

    The molecule is first encoded with the requested slicer. If that fails with a
    fragmentation error and ``split_fragment`` is set, a second attempt is made
    with ``fragment_aware_spliting`` as the slicer (never canonical).

    Args:
        mol: molecule to convert
        canonical: whether to use canonical encoding
        randomize: whether to randomize the encoding
        seed: random seed
        slicer: the slicer to use for fragmentation
        split_fragment: whether to split fragments
        fraction_hs: proportion of random atom to which we will add explicit hydrogens
        resolution: resolution for the partitioning algorithm

    Returns:
        The SAFE string, or None when the encoding failed.
    """
    try:
        return sf.encode(mol, canonical=canonical, randomize=randomize, slicer=slicer, seed=seed)
    except sf.SAFEFragmentationError:
        # no fallback requested, or input containing a "." : give up
        if not split_fragment or "." in mol:
            return None
    except sf.SAFEEncodeError:
        return None

    # second chance: partition-based slicing that is allowed to cut any bond
    fallback_slicer = partial(
        fragment_aware_spliting,
        fraction_hs=fraction_hs,
        resolution=resolution,
        seed=seed,
    )
    try:
        return sf.encode(
            mol,
            canonical=False,
            randomize=randomize,
            seed=seed,
            slicer=fallback_slicer,
        )
    except (sf.SAFEEncodeError, sf.SAFEFragmentationError):
        return None

filter_by_substructure_constraints(sequences, substruct, n_jobs=-1)

Check whether the input substructure is present in each of the molecules in the sequences

Parameters:

Name Type Description Default
sequences List[Union[str, Mol]]

list of molecules to validate

required
substruct Union[str, Mol]

substructure to use as query

required
n_jobs int

number of jobs to use for parallelization

-1
Source code in safe/utils.py
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
def filter_by_substructure_constraints(
    sequences: List[Union[str, dm.Mol]], substruct: Union[str, dm.Mol], n_jobs: int = -1
):
    """Keep the molecules of ``sequences`` that contain the query substructure

    Args:
        sequences: list of molecules to validate
        substruct: substructure to use as query. A string has its attachment
            points standardized and is then parsed as SMARTS
        n_jobs: number of jobs to use for parallelization

    Returns:
        The entries of ``sequences`` that match the query, in their original order.
        Entries for which the check raises are treated as non-matching.
    """

    query = substruct
    if isinstance(query, str):
        query = dm.from_smarts(standardize_attach(query))

    def _has_query(candidate):
        try:
            return dm.to_mol(candidate).HasSubstructMatch(query)
        except Exception:
            # any failure while parsing or matching counts as "no match"
            return False

    flags = dm.parallelized(_has_query, sequences, n_jobs=n_jobs)
    return [entry for entry, keep in zip(sequences, flags) if keep]

find_partition_edges(G, partition)

Find the edges connecting the subgraphs in a given partition of a graph.

Parameters:

Name Type Description Default
G Graph

The original graph.

required
partition list of list of nodes

The partition of the graph where each element is a list of nodes representing a subgraph.

required

Returns:

Name Type Description
list List[Tuple]

A list of edges connecting the subgraphs in the partition.

Source code in safe/utils.py
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
def find_partition_edges(G: nx.Graph, partition: List[List]) -> List[Tuple]:
    """
    Collect the edges of a graph that run between different parts of a partition.

    Args:
        G (networkx.Graph): The original graph.
        partition (list of list of nodes): The partition of the graph where each element is a list of nodes representing a subgraph.

    Returns:
        list: A list of edges connecting the subgraphs in the partition.
    """
    # every unordered pair of parts contributes its boundary edges
    return [
        edge
        for part_a, part_b in combinations(partition, 2)
        for edge in nx.edge_boundary(G, part_a, part_b)
    ]

fragment_aware_spliting(mol, fraction_hs=None, **kwargs)

Custom splitting algorithm for dataset building.

This slicing strategy will cut any bond, including bonds to hydrogens. However, only one cut per atom is allowed.

Parameters:

Name Type Description Default
mol Mol

molecule to split

required
fraction_hs Optional[bool]

proportion of random atom to which we will add explicit hydrogens

None
kwargs Any

additional arguments to pass to the partitioning algorithm

{}
Source code in safe/utils.py
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
def fragment_aware_spliting(mol: dm.Mol, fraction_hs: Optional[bool] = None, **kwargs: Any):
    """Custom splitting algorithm for dataset building.

    This slicing strategy will cut any bond, including bonds to hydrogens.
    However, only one cut per atom is allowed.

    Args:
        mol: molecule to split
        fraction_hs: proportion of random atom to which we will add explicit hydrogens
        kwargs: additional arguments to pass to the partitioning algorithm.
            ``seed`` (default 1) is also used to seed the global ``random`` module.

    Returns:
        The edges of the molecular graph that connect the parts of the last
        partition produced by ``mol_partition``.
    """
    random.seed(kwargs.get("seed", 1))
    mol_with_hs = _selective_add_hs(dm.to_mol(mol, remove_hs=False), fraction_hs=fraction_hs)
    mol_graph = dm.graph.to_graph(mol_with_hs)
    # only the final partition yielded by the generator is of interest
    last_level = deque(mol_partition(mol_with_hs, **kwargs), maxlen=1)
    return find_partition_edges(mol_graph, last_level.pop())

list_individual_attach_points(mol, depth=None)

List all individual attachment points.

We do not allow multiple attachment points per substitution position.

Parameters:

Name Type Description Default
mol Mol

molecule for which we need to open the attachment points

required
Source code in safe/utils.py
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
def list_individual_attach_points(mol: dm.Mol, depth: Optional[int] = None):
    """List all individual attachment points.

    We do not allow multiple attachment points per substitution position.

    Args:
        mol: molecule for which we need to open the attachment points
        depth: number of successive rounds of the attaching reaction. Each round
            is applied to the products of the previous one. Defaults to 1 and is
            clamped between 1 and the number of atoms matching ``[*;h:1]``.

    Returns:
        A list of the unique product SMILES collected over all rounds.
    """
    attach_rxn = ReactionFromSmarts("[*;h;!$([*][#0]):1]>>[*:1][*]")
    open_site_query = dm.from_smarts("[*;h:1]")
    n_open_sites = len(mol.GetSubstructMatches(open_site_query, uniquify=True))
    rounds_left = min(max(depth or 1, 1), n_open_sites)

    frontier = [mol]
    collected = set()
    while rounds_left > 0:
        round_products = set()
        for parent in frontier:
            parent_mol = dm.to_mol(parent)
            for product in attach_rxn.RunReactants((parent_mol,)):
                try:
                    sanitized = dm.sanitize_mol(product[0])
                    smiles = dm.to_smiles(sanitized, canonical=True)
                    smiles = dm.reactions.add_brackets_to_attachment_points(smiles)
                    round_products.add(
                        dm.reactions.convert_attach_to_isotope(smiles, as_smiles=True)
                    )
                except Exception as e:
                    # a product that cannot be processed is logged and skipped
                    logger.error(e)
        collected |= round_products
        # the next round starts from what this round produced
        frontier = round_products
        rounds_left -= 1
    return list(collected)

mol_partition(mol, query=None, seed=None, **kwargs)

Partition a molecule into fragments using a bond query

Parameters:

Name Type Description Default
mol Mol

molecule to split

required
query Optional[Mol]

bond query to use for splitting

None
seed Optional[int]

random seed

None
kwargs Any

additional arguments to pass to the partitioning algorithm

{}
Source code in safe/utils.py
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
@py_random_state("seed")
def mol_partition(
    mol: dm.Mol, query: Optional[dm.Mol] = None, seed: Optional[int] = None, **kwargs: Any
):
    """Partition a molecule into fragments using a bond query

    This is a generator: it yields successive partitions of the atoms of ``mol``
    and stops once the modularity gain between two levels is at most ``threshold``.
    At least one partition is always yielded.

    Args:
        mol: molecule to split
        query: bond query to use for splitting. Defaults to the module-level
            ``__mmpa_query`` when None
        seed: random seed
        kwargs: additional arguments to pass to the partitioning algorithm.
            The keys read here are ``resolution`` (default 1.0),
            ``threshold`` (default 1e-7) and ``weight`` (default "weight").

    Yields:
        A list of sets of graph nodes; each set is a copy, so callers may mutate it.

    """
    # NOTE(review): the loop below appears to be adapted from networkx's Louvain
    # implementation and relies on its private helpers `_gen_graph` and
    # `_one_level`; confirm compatibility when upgrading networkx.
    resolution = kwargs.get("resolution", 1.0)
    threshold = kwargs.get("threshold", 1e-7)
    weight = kwargs.get("weight", "weight")

    if query is None:
        query = __mmpa_query

    G = dm.graph.to_graph(mol)
    # atom-index tuples (sorted) of every match of the bond query
    bond_partition = [
        tuple(sorted(match)) for match in mol.GetSubstructMatches(query, uniquify=True)
    ]

    def get_relevant_edges(e1, e2):
        # keep only the edges that are NOT matched by the bond query
        return tuple(sorted([e1, e2])) not in bond_partition

    subgraphs = nx.subgraph_view(G, filter_edge=get_relevant_edges)

    # NOTE(review): this singleton partition is reassigned by the `_one_level`
    # call below before it is ever read.
    partition = [{u} for u in G.nodes()]
    # starting communities: connected components once the query-matched edges
    # are filtered out, ordered by their smallest node
    inner_partition = sorted(nx.connected_components(subgraphs), key=lambda x: min(x))
    mod = nx.algorithms.community.modularity(
        G, inner_partition, resolution=resolution, weight=weight
    )
    is_directed = G.is_directed()
    # rebuild a graph of the same class carrying only nodes and edge weights
    # (edges without the weight attribute get weight 1)
    graph = G.__class__()
    graph.add_nodes_from(G)
    graph.add_weighted_edges_from(G.edges(data=weight, default=1))
    graph = nx.algorithms.community.louvain._gen_graph(graph, inner_partition)
    m = graph.size(weight="weight")
    partition, inner_partition, improvement = nx.algorithms.community.louvain._one_level(
        graph, m, inner_partition, resolution, is_directed, seed
    )
    # force at least one iteration (and thus one yield) regardless of the flag
    # returned above
    improvement = True
    while improvement:
        # gh-5901 protect the sets in the yielded list from further manipulation here
        yield [s.copy() for s in partition]
        new_mod = nx.algorithms.community.modularity(
            graph, inner_partition, resolution=resolution, weight="weight"
        )
        # stop as soon as the modularity gain no longer exceeds the threshold
        if new_mod - mod <= threshold:
            return
        mod = new_mod
        graph = nx.algorithms.community.louvain._gen_graph(graph, inner_partition)
        partition, inner_partition, improvement = nx.algorithms.community.louvain._one_level(
            graph, m, partition, resolution, is_directed, seed
        )

standardize_attach(inputs, standard_attach='[*]')

Standardize the attachment points of a molecule

Parameters:

Name Type Description Default
inputs str

input molecule

required
standard_attach str

standard attachment point to use

'[*]'
Source code in safe/utils.py
571
572
573
574
575
576
577
578
579
580
581
def standardize_attach(inputs: str, standard_attach: str = "[*]"):
    """Rewrite every attachment point of a molecule string to a single notation

    Each pattern of ``_SMILES_ATTACHMENT_POINTS`` is substituted in turn.

    Args:
        inputs: input molecule
        standard_attach: standard attachment point to use

    Returns:
        The input string with all recognized attachment points replaced.
    """

    standardized = inputs
    for pattern in _SMILES_ATTACHMENT_POINTS:
        standardized = re.sub(pattern, standard_attach, standardized)
    return standardized