Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import os.path | |
| import subprocess | |
| import torch | |
| from Bio.PDB import PDBParser | |
| from src import const | |
| from src.visualizer import save_xyz_file | |
| from src.utils import FoundNaNException | |
| from src.datasets import get_one_hot | |
def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=0, max_attempts=5):
    """Sample linkers with ``ddpm`` and save every sample of the batch as an XYZ file.

    Sampling is retried when ``ddpm.sample_chain`` raises ``FoundNaNException``.
    The sampled coordinates are then shifted by the mean position of the atoms
    selected by the centre-of-mass mask, putting the molecule back in the
    orientation of the input ``data['positions']``.

    Args:
        ddpm: model object; this function uses its ``sample_chain`` method and
            its ``n_dims`` and ``center_of_mass`` attributes.
        data: batch dict. Keys read here: ``'positions'``, ``'anchors'``, and
            ``'fragment_only_mask'`` + ``'pocket_mask'`` when ``with_pocket``
            is true, otherwise ``'fragment_mask'``.
        sample_fn: forwarded unchanged to ``ddpm.sample_chain``.
        name: suffix used in the per-sample output names.
        with_pocket: if true, pocket atoms (``data['pocket_mask']``) are zeroed
            in the node mask before saving.
        offset_idx: added to the 1-based sample index in the output names.
        max_attempts: how many times sampling is attempted before giving up.

    Raises:
        RuntimeError: if every sampling attempt raised ``FoundNaNException``.
    """
    chain = node_mask = None
    last_error = None
    for _ in range(max_attempts):
        try:
            chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
            break
        except FoundNaNException as err:
            last_error = err
    else:
        # Without this, ``chain`` would still be None and ``chain[0]`` below would
        # fail with an unrelated-looking TypeError.
        raise RuntimeError(
            f'Sampling produced NaNs in all {max_attempts} attempts'
        ) from last_error
    print('Generated linker')

    # The last frame holds coordinates in the first n_dims channels and atom
    # features in the remaining ones.
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]

    # Put the molecule back to the initial orientation: add back the mean
    # position of the atoms used as the centre of mass.
    if ddpm.center_of_mass == 'fragments':
        com_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
    else:
        com_mask = data['anchors']
    pos_masked = data['positions'] * com_mask
    N = com_mask.sum(1, keepdims=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
    x = x + mean * node_mask

    if with_pocket:
        # Presumably so save_xyz_file skips pocket atoms -- TODO confirm.
        node_mask[torch.where(data['pocket_mask'])] = 0

    batch_size = len(data['positions'])
    names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')
def try_to_convert_to_sdf(name, num_samples):
    """Convert generated XYZ files to SDF with Open Babel, falling back to XYZ.

    For each sample index ``i`` in ``1..num_samples`` this runs
    ``obabel results/output_{i}_{name}_.xyz -O results/output_{i}_{name}_.sdf``.

    Args:
        name: name fragment embedded in the file paths.
        num_samples: number of samples to convert.

    Returns:
        One path per sample: the SDF path if that file exists after the
        conversion attempt, otherwise the XYZ path.
    """
    out_files = []
    for i in range(1, num_samples + 1):
        out_xyz = f'results/output_{i}_{name}_.xyz'
        out_sdf = f'results/output_{i}_{name}_.sdf'
        # Pass argv as a list (no shell) so a ``name`` containing spaces or shell
        # metacharacters is treated as part of a filename, never as shell syntax.
        try:
            subprocess.run(['obabel', out_xyz, '-O', out_sdf], check=False)
        except OSError:
            # obabel missing/not executable: deliberately fall back to the XYZ
            # file, matching the old shell behaviour ("command not found").
            pass
        out_files.append(out_sdf if os.path.exists(out_sdf) else out_xyz)
    return out_files
def get_pocket(mol, pdb_path, cutoff=6):
    """Select protein residues in contact with a ligand and featurize their atoms.

    A residue belongs to the pocket if any of its atoms lies within ``cutoff``
    (in the units of the input coordinates, presumably Angstrom) of any ligand
    atom. Pocket atoms whose upper-cased element is not a key of
    ``const.GEOM_ATOM2IDX`` are dropped.

    Args:
        mol: ligand exposing ``GetConformer().GetPositions()`` (e.g. an RDKit Mol).
        pdb_path: path of the protein structure, parsed with ``PDBParser``.
        cutoff: contact distance threshold (default 6, as before).

    Returns:
        ``(pocket_pos, pocket_one_hot, pocket_charges)`` as numpy arrays, one
        entry per kept pocket atom in structure order: coordinates (3 floats
        each), ``get_one_hot`` encodings, and ``const.GEOM_CHARGES`` values.
    """
    struct = PDBParser().get_structure('', pdb_path)
    residues = list(struct.get_residues())

    # Residues are keyed by their full id (model, chain, hetflag, number,
    # insertion code). Keying by residue number alone would merge
    # same-numbered residues from different chains.
    residue_keys = []
    atom_coords = []
    for residue in residues:
        key = residue.get_full_id()
        for atom in residue.get_atoms():
            atom_coords.append(atom.get_coord())
            residue_keys.append(key)
    # reshape keeps a (0, 3) shape for an empty structure so the broadcast below works.
    atom_coords = np.asarray(atom_coords, dtype=float).reshape(-1, 3)

    mol_atom_coords = mol.GetConformer().GetPositions()
    # Pairwise distances: rows are protein atoms, columns are ligand atoms.
    distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
    close_atoms = np.flatnonzero(distances.min(axis=1) <= cutoff)
    contact_residues = {residue_keys[idx] for idx in close_atoms}

    pocket_pos = []
    pocket_one_hot = []
    pocket_charges = []
    for residue in residues:
        if residue.get_full_id() not in contact_residues:
            continue
        for atom in residue.get_atoms():
            atom_type = atom.element.upper()
            if atom_type not in const.GEOM_ATOM2IDX:
                continue
            pocket_pos.append(atom.get_coord().tolist())
            pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
            pocket_charges.append(const.GEOM_CHARGES[atom_type])

    pocket_pos = np.array(pocket_pos)
    pocket_one_hot = np.array(pocket_one_hot)
    pocket_charges = np.array(pocket_charges)
    return pocket_pos, pocket_one_hot, pocket_charges