Spaces:
Runtime error
Runtime error
updated code
Browse files
- app.py +8 -4
- src/datasets.py +114 -9
- src/egnn.py +48 -13
- src/lightning.py +8 -4
- src/linker_size.py +0 -4
- src/linker_size_lightning.py +6 -1
- src/utils.py +14 -0
app.py
CHANGED
|
@@ -35,12 +35,16 @@ MODELS_METADATA = {
|
|
| 35 |
'path': 'models/geom_difflinker_given_anchors.ckpt',
|
| 36 |
},
|
| 37 |
'pockets_difflinker': {
|
| 38 |
-
'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
|
| 39 |
-
'path': 'models/pockets_difflinker.ckpt',
|
|
|
|
|
|
|
| 40 |
},
|
| 41 |
'pockets_difflinker_given_anchors': {
|
| 42 |
-
'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
|
| 43 |
-
'path': 'models/pockets_difflinker_given_anchors.ckpt',
|
|
|
|
|
|
|
| 44 |
},
|
| 45 |
}
|
| 46 |
|
|
|
|
| 35 |
'path': 'models/geom_difflinker_given_anchors.ckpt',
|
| 36 |
},
|
| 37 |
'pockets_difflinker': {
|
| 38 |
+
# 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full_no_anchors.ckpt?download=1',
|
| 39 |
+
# 'path': 'models/pockets_difflinker.ckpt',
|
| 40 |
+
'link': 'https://zenodo.org/records/10988017/files/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt?download=1',
|
| 41 |
+
'path': 'models/pockets_difflinker_full_no_anchors_fc_pdb_excluded.ckpt',
|
| 42 |
},
|
| 43 |
'pockets_difflinker_given_anchors': {
|
| 44 |
+
# 'link': 'https://zenodo.org/record/7775568/files/pockets_difflinker_full.ckpt?download=1',
|
| 45 |
+
# 'path': 'models/pockets_difflinker_given_anchors.ckpt',
|
| 46 |
+
'link': 'https://zenodo.org/records/10988017/files/pockets_difflinker_full_fc_pdb_excluded.ckpt?download=1',
|
| 47 |
+
'path': 'models/pockets_difflinker_full_fc_pdb_excluded.ckpt',
|
| 48 |
},
|
| 49 |
}
|
| 50 |
|
src/datasets.py
CHANGED
|
@@ -148,6 +148,15 @@ class MOADDataset(Dataset):
|
|
| 148 |
total=len(table)
|
| 149 |
)
|
| 150 |
for (_, row), fragments, linker, pocket_data in generator:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
uuid = row['uuid']
|
| 152 |
name = row['molecule']
|
| 153 |
frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
|
|
@@ -212,16 +221,112 @@ class MOADDataset(Dataset):
|
|
| 212 |
|
| 213 |
return data
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
@staticmethod
|
| 216 |
-
def
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
| 225 |
|
| 226 |
|
| 227 |
def collate(batch):
|
|
@@ -231,7 +336,7 @@ def collate(batch):
|
|
| 231 |
# if 'pocket_mask' not in batch[0].keys():
|
| 232 |
# batch = [data for data in batch if data['num_atoms'] <= 50]
|
| 233 |
# else:
|
| 234 |
-
#
|
| 235 |
|
| 236 |
for i, data in enumerate(batch):
|
| 237 |
for key, value in data.items():
|
|
|
|
| 148 |
total=len(table)
|
| 149 |
)
|
| 150 |
for (_, row), fragments, linker, pocket_data in generator:
|
| 151 |
+
pdb = row['molecule_name'].split('_')[0]
|
| 152 |
+
if pdb in {
|
| 153 |
+
'5ou2', '5ou3', '6hay',
|
| 154 |
+
'5mo8', '5mo5', '5mo7', '5ctp', '5cu2', '5cu4', '5mmr', '5mmf',
|
| 155 |
+
'5moe', '3iw7', '4i9n', '3fi2', '3fi3',
|
| 156 |
+
}:
|
| 157 |
+
print(f'Skipping pdb={pdb}')
|
| 158 |
+
continue
|
| 159 |
+
|
| 160 |
uuid = row['uuid']
|
| 161 |
name = row['molecule']
|
| 162 |
frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
|
|
|
|
| 221 |
|
| 222 |
return data
|
| 223 |
|
| 224 |
+
|
| 225 |
+
class OptimisedMOADDataset(MOADDataset):
    """MOAD dataset variant that keeps its samples in two separate tables.

    ``self.data`` is read as the dict produced by :meth:`preprocess`:

    * ``'fragmentation_level_data'`` -- a list with one record per table row
      (uuid, name, anchors and the four masks);
    * ``'protein_level_data'`` -- a dict keyed by the row's ``name`` holding
      positions, one-hot atom types, charges and the atom count.

    NOTE(review): how ``self.data`` gets populated is not visible in this
    class; presumably the ``MOADDataset`` constructor calls ``preprocess`` --
    confirm against the parent class.
    """
    # TODO: finish testing

    def __len__(self):
        """Return the number of fragmentation-level records."""
        return len(self.data['fragmentation_level_data'])

    def __getitem__(self, item):
        """Return one sample: the fragmentation record merged with its protein record.

        The protein-level entry is looked up by the record's ``'name'``; on a
        key clash the protein-level value wins because it is unpacked last.
        """
        fragmentation_level_data = self.data['fragmentation_level_data'][item]
        protein_level_data = self.data['protein_level_data'][fragmentation_level_data['name']]
        return {
            **fragmentation_level_data,
            **protein_level_data,
        }

    @staticmethod
    def preprocess(data_path, prefix, pocket_mode, device):
        """Read the raw ``{prefix}_*`` files and build the two-table representation.

        Args:
            data_path: directory containing ``{prefix}_table.csv``,
                ``{prefix}_frag.sdf``, ``{prefix}_link.sdf`` and
                ``{prefix}_pockets.pkl``.
            prefix: file-name prefix; if it contains ``'multifrag'`` the anchors
                are read from the dash-separated ``'anchors'`` column, otherwise
                from the ``'anchor_1'`` / ``'anchor_2'`` columns.
            pocket_mode: selects the ``f'{pocket_mode}_coord'`` and
                ``f'{pocket_mode}_types'`` entries of every pocket record.
            device: torch device on which all tensors are created.

        Returns:
            dict with keys ``'fragmentation_level_data'`` (list of dicts) and
            ``'protein_level_data'`` (dict keyed by molecule name).
        """
        print('Preprocessing optimised version of the dataset')
        protein_level_data = {}
        fragmentation_level_data = []

        table_path = os.path.join(data_path, f'{prefix}_table.csv')
        fragments_path = os.path.join(data_path, f'{prefix}_frag.sdf')
        linkers_path = os.path.join(data_path, f'{prefix}_link.sdf')
        pockets_path = os.path.join(data_path, f'{prefix}_pockets.pkl')

        # Molecules are always parsed with the GEOM atom vocabulary here.
        is_geom = True
        is_multifrag = 'multifrag' in prefix

        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point this at pocket files from a trusted source.
        with open(pockets_path, 'rb') as f:
            pockets = pickle.load(f)

        table = pd.read_csv(table_path)
        # The four sources are consumed in lockstep, one entry per table row.
        generator = tqdm(
            zip(table.iterrows(), read_sdf(fragments_path), read_sdf(linkers_path), pockets),
            total=len(table)
        )
        for (_, row), fragments, linker, pocket_data in generator:
            uuid = row['uuid']
            name = row['molecule']
            frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments, is_geom=is_geom)
            link_pos, link_one_hot, link_charges = parse_molecule(linker, is_geom=is_geom)

            # Parsing pocket data
            pocket_pos = pocket_data[f'{pocket_mode}_coord']
            pocket_one_hot = []
            pocket_charges = []
            for atom_type in pocket_data[f'{pocket_mode}_types']:
                pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
                pocket_charges.append(const.GEOM_CHARGES[atom_type])
            pocket_one_hot = np.array(pocket_one_hot)
            pocket_charges = np.array(pocket_charges)

            # Atom order in every per-atom array: fragments, then pocket, then linker.
            positions = np.concatenate([frag_pos, pocket_pos, link_pos], axis=0)
            one_hot = np.concatenate([frag_one_hot, pocket_one_hot, link_one_hot], axis=0)
            charges = np.concatenate([frag_charges, pocket_charges, link_charges], axis=0)
            anchors = np.zeros_like(charges)

            if is_multifrag:
                for anchor_idx in map(int, row['anchors'].split('-')):
                    anchors[anchor_idx] = 1
            else:
                anchors[row['anchor_1']] = 1
                anchors[row['anchor_2']] = 1

            # 0/1 masks over the concatenated atom axis.
            fragment_only_mask = np.concatenate([
                np.ones_like(frag_charges),
                np.zeros_like(pocket_charges),
                np.zeros_like(link_charges)
            ])
            pocket_mask = np.concatenate([
                np.zeros_like(frag_charges),
                np.ones_like(pocket_charges),
                np.zeros_like(link_charges)
            ])
            linker_mask = np.concatenate([
                np.zeros_like(frag_charges),
                np.zeros_like(pocket_charges),
                np.ones_like(link_charges)
            ])
            # 'fragment_mask' covers fragments AND pocket (everything but the linker).
            fragment_mask = np.concatenate([
                np.ones_like(frag_charges),
                np.ones_like(pocket_charges),
                np.zeros_like(link_charges)
            ])

            fragmentation_level_data.append({
                'uuid': uuid,
                'name': name,
                'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
                'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
                'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
                'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
                'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
            })
            # NOTE(review): this entry is keyed by `name` and rewritten on every
            # row, so if two rows share a name the later row's tensors replace
            # the earlier ones while the earlier masks stay in the list above.
            # Whether names repeat in the table is not visible here -- confirm.
            protein_level_data[name] = {
                'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
                'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
                'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
                'num_atoms': len(positions),
            }

        return {
            'fragmentation_level_data': fragmentation_level_data,
            'protein_level_data': protein_level_data,
        }
|
| 330 |
|
| 331 |
|
| 332 |
def collate(batch):
|
|
|
|
| 336 |
# if 'pocket_mask' not in batch[0].keys():
|
| 337 |
# batch = [data for data in batch if data['num_atoms'] <= 50]
|
| 338 |
# else:
|
| 339 |
+
# batch = [data for data in batch if data['num_atoms'] <= 1000]
|
| 340 |
|
| 341 |
for i, data in enumerate(batch):
|
| 342 |
for key, value in data.items():
|
src/egnn.py
CHANGED
|
@@ -315,7 +315,7 @@ class Dynamics(nn.Module):
|
|
| 315 |
self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
|
| 316 |
n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
|
| 317 |
sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
|
| 318 |
-
normalization=None, centering=False,
|
| 319 |
):
|
| 320 |
super().__init__()
|
| 321 |
self.device = device
|
|
@@ -324,6 +324,7 @@ class Dynamics(nn.Module):
|
|
| 324 |
self.condition_time = condition_time
|
| 325 |
self.model = model
|
| 326 |
self.centering = centering
|
|
|
|
| 327 |
|
| 328 |
in_node_nf = in_node_nf + context_node_nf + condition_time
|
| 329 |
if self.model == 'egnn_dynamics':
|
|
@@ -369,6 +370,8 @@ class Dynamics(nn.Module):
|
|
| 369 |
- context: (B, N, C)
|
| 370 |
"""
|
| 371 |
|
|
|
|
|
|
|
| 372 |
bs, n_nodes = xh.shape[0], xh.shape[1]
|
| 373 |
edges = self.get_edges(n_nodes, bs) # (2, B*N)
|
| 374 |
node_mask = node_mask.view(bs * n_nodes, 1) # (B*N, 1)
|
|
@@ -421,16 +424,6 @@ class Dynamics(nn.Module):
|
|
| 421 |
if self.condition_time:
|
| 422 |
h_final = h_final[:, :-1]
|
| 423 |
|
| 424 |
-
if torch.any(torch.isnan(vel)):
|
| 425 |
-
print('Found NaN values in velocities')
|
| 426 |
-
nan_mask = torch.isnan(vel).float()
|
| 427 |
-
vel = x * nan_mask + torch.nan_to_num(vel) * (1 - nan_mask)
|
| 428 |
-
|
| 429 |
-
if torch.any(torch.isnan(h_final)):
|
| 430 |
-
print('Found NaN values in features')
|
| 431 |
-
nan_mask = torch.isnan(h_final).float()
|
| 432 |
-
h_final = h[:, :h_final.shape[1]] * nan_mask + torch.nan_to_num(h_final) * (1 - nan_mask)
|
| 433 |
-
|
| 434 |
vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
|
| 435 |
h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
|
| 436 |
node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
|
|
@@ -477,12 +470,21 @@ class DynamicsWithPockets(Dynamics):
|
|
| 477 |
if linker_mask is not None:
|
| 478 |
linker_mask = linker_mask.view(bs * n_nodes, 1) # (B*N, 1)
|
| 479 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
# Reshaping node features & adding time feature
|
| 481 |
xh = xh.view(bs * n_nodes, -1).clone() * node_mask # (B*N, D)
|
| 482 |
x = xh[:, :self.n_dims].clone() # (B*N, 3)
|
| 483 |
h = xh[:, self.n_dims:].clone() # (B*N, nf)
|
| 484 |
|
| 485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
if self.condition_time:
|
| 487 |
if np.prod(t.size()) == 1:
|
| 488 |
# t is the same for all elements in batch.
|
|
@@ -537,7 +539,7 @@ class DynamicsWithPockets(Dynamics):
|
|
| 537 |
return torch.cat([vel, h_final], dim=2)
|
| 538 |
|
| 539 |
@staticmethod
|
| 540 |
-
def
|
| 541 |
node_mask = node_mask.squeeze().bool()
|
| 542 |
batch_adj = (batch_mask[:, None] == batch_mask[None, :])
|
| 543 |
nodes_adj = (node_mask[:, None] & node_mask[None, :])
|
|
@@ -546,3 +548,36 @@ class DynamicsWithPockets(Dynamics):
|
|
| 546 |
adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
|
| 547 |
edges = torch.stack(torch.where(adj))
|
| 548 |
return edges
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
|
| 316 |
n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
|
| 317 |
sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
|
| 318 |
+
normalization=None, centering=False, graph_type='FC',
|
| 319 |
):
|
| 320 |
super().__init__()
|
| 321 |
self.device = device
|
|
|
|
| 324 |
self.condition_time = condition_time
|
| 325 |
self.model = model
|
| 326 |
self.centering = centering
|
| 327 |
+
self.graph_type = graph_type
|
| 328 |
|
| 329 |
in_node_nf = in_node_nf + context_node_nf + condition_time
|
| 330 |
if self.model == 'egnn_dynamics':
|
|
|
|
| 370 |
- context: (B, N, C)
|
| 371 |
"""
|
| 372 |
|
| 373 |
+
assert self.graph_type == 'FC'
|
| 374 |
+
|
| 375 |
bs, n_nodes = xh.shape[0], xh.shape[1]
|
| 376 |
edges = self.get_edges(n_nodes, bs) # (2, B*N)
|
| 377 |
node_mask = node_mask.view(bs * n_nodes, 1) # (B*N, 1)
|
|
|
|
| 424 |
if self.condition_time:
|
| 425 |
h_final = h_final[:, :-1]
|
| 426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
vel = vel.view(bs, n_nodes, -1) # (B, N, 3)
|
| 428 |
h_final = h_final.view(bs, n_nodes, -1) # (B, N, D)
|
| 429 |
node_mask = node_mask.view(bs, n_nodes, 1) # (B, N, 1)
|
|
|
|
| 470 |
if linker_mask is not None:
|
| 471 |
linker_mask = linker_mask.view(bs * n_nodes, 1) # (B*N, 1)
|
| 472 |
|
| 473 |
+
fragment_only_mask = context[..., -2].view(bs * n_nodes, 1) # (B*N, 1)
|
| 474 |
+
pocket_only_mask = context[..., -1].view(bs * n_nodes, 1) # (B*N, 1)
|
| 475 |
+
assert torch.all(fragment_only_mask.bool() | pocket_only_mask.bool() | linker_mask.bool() == node_mask.bool())
|
| 476 |
+
|
| 477 |
# Reshaping node features & adding time feature
|
| 478 |
xh = xh.view(bs * n_nodes, -1).clone() * node_mask # (B*N, D)
|
| 479 |
x = xh[:, :self.n_dims].clone() # (B*N, 3)
|
| 480 |
h = xh[:, self.n_dims:].clone() # (B*N, nf)
|
| 481 |
|
| 482 |
+
assert self.graph_type in ['4A', 'FC-4A', 'FC-10A-4A']
|
| 483 |
+
if self.graph_type == '4A' or self.graph_type is None:
|
| 484 |
+
edges = self.get_dist_edges_4A(x, node_mask, edge_mask)
|
| 485 |
+
else:
|
| 486 |
+
edges = self.get_dist_edges(x, node_mask, edge_mask, linker_mask, fragment_only_mask, pocket_only_mask)
|
| 487 |
+
|
| 488 |
if self.condition_time:
|
| 489 |
if np.prod(t.size()) == 1:
|
| 490 |
# t is the same for all elements in batch.
|
|
|
|
| 539 |
return torch.cat([vel, h_final], dim=2)
|
| 540 |
|
| 541 |
@staticmethod
|
| 542 |
+
def get_dist_edges_4A(x, node_mask, batch_mask):
|
| 543 |
node_mask = node_mask.squeeze().bool()
|
| 544 |
batch_adj = (batch_mask[:, None] == batch_mask[None, :])
|
| 545 |
nodes_adj = (node_mask[:, None] & node_mask[None, :])
|
|
|
|
| 548 |
adj = batch_adj & nodes_adj & dists_adj & rm_self_loops
|
| 549 |
edges = torch.stack(torch.where(adj))
|
| 550 |
return edges
|
| 551 |
+
|
| 552 |
+
def get_dist_edges(self, x, node_mask, batch_mask, linker_mask, fragment_only_mask, pocket_only_mask):
    """Build the edge index of the mixed ligand/pocket interaction graph.

    An edge (i, j) is kept when both nodes are unmasked, share the same
    ``batch_mask`` value, differ (no self loops) and one of these holds:

    * both are ligand atoms (linker or fragment) -- fully connected;
    * both are pocket atoms and at most 4 apart;
    * one is a ligand atom and the other a pocket atom, and they are at most
      4 apart when ``self.graph_type == 'FC-4A'`` and at most 10 apart for
      any other graph type.

    Args:
        x: node coordinates, one row per node.
        node_mask: per-node validity mask; flattened to one value per node.
        batch_mask: per-node batch identifier, indexed as a 1-D tensor.
        linker_mask, fragment_only_mask, pocket_only_mask: per-node 0/1
            masks; flattened to one value per node.

    Returns:
        Tensor of shape (2, E): source indices in row 0, target indices in row 1.
    """
    # reshape(-1) rather than squeeze(): squeeze() would turn a single-node
    # (1, 1) mask into a 0-dim tensor and break the [:, None] indexing below.
    node_mask = node_mask.reshape(-1).bool()
    linker_mask = linker_mask.reshape(-1).bool() & node_mask
    fragment_only_mask = fragment_only_mask.reshape(-1).bool() & node_mask
    pocket_only_mask = pocket_only_mask.reshape(-1).bool() & node_mask
    ligand_mask = linker_mask | fragment_only_mask

    # General constraints: same batch element, both nodes valid, no self loops.
    batch_adj = (batch_mask[:, None] == batch_mask[None, :])
    nodes_adj = (node_mask[:, None] & node_mask[None, :])
    rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
    constraints = batch_adj & nodes_adj & rm_self_loops

    # Pairwise distances are needed by both distance-based rules; compute once.
    dists = torch.cdist(x, x)

    # Ligand atoms - fully-connected graph
    ligand_adj = (ligand_mask[:, None] & ligand_mask[None, :])
    ligand_interactions = ligand_adj & constraints

    # Pocket atoms - within 4A
    pocket_adj = (pocket_only_mask[:, None] & pocket_only_mask[None, :])
    pocket_interactions = pocket_adj & (dists <= 4) & constraints

    # Pocket-ligand atoms - within 4A for 'FC-4A', within 10A otherwise
    pocket_ligand_cutoff = 4 if self.graph_type == 'FC-4A' else 10
    pocket_ligand_adj = (ligand_mask[:, None] & pocket_only_mask[None, :])
    pocket_ligand_adj = pocket_ligand_adj | (pocket_only_mask[:, None] & ligand_mask[None, :])
    pocket_ligand_interactions = pocket_ligand_adj & (dists <= pocket_ligand_cutoff) & constraints

    adj = ligand_interactions | pocket_interactions | pocket_ligand_interactions
    edges = torch.stack(torch.where(adj))
    return edges
|
src/lightning.py
CHANGED
|
@@ -44,7 +44,7 @@ class DDPM(pl.LightningModule):
|
|
| 44 |
normalize_factors, include_charges, model,
|
| 45 |
data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
|
| 46 |
normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
|
| 47 |
-
center_of_mass='fragments', inpainting=False, anchors_context=True,
|
| 48 |
):
|
| 49 |
super(DDPM, self).__init__()
|
| 50 |
|
|
@@ -54,7 +54,7 @@ class DDPM(pl.LightningModule):
|
|
| 54 |
self.val_data_prefix = val_data_prefix
|
| 55 |
self.batch_size = batch_size
|
| 56 |
self.lr = lr
|
| 57 |
-
self.torch_device =
|
| 58 |
self.include_charges = include_charges
|
| 59 |
self.test_epochs = test_epochs
|
| 60 |
self.n_stability_samples = n_stability_samples
|
|
@@ -72,6 +72,9 @@ class DDPM(pl.LightningModule):
|
|
| 72 |
|
| 73 |
self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
|
| 74 |
|
|
|
|
|
|
|
|
|
|
| 75 |
if type(activation) is str:
|
| 76 |
activation = get_activation(activation)
|
| 77 |
|
|
@@ -80,7 +83,7 @@ class DDPM(pl.LightningModule):
|
|
| 80 |
in_node_nf=in_node_nf,
|
| 81 |
n_dims=n_dims,
|
| 82 |
context_node_nf=context_node_nf,
|
| 83 |
-
device=
|
| 84 |
hidden_nf=hidden_nf,
|
| 85 |
activation=activation,
|
| 86 |
n_layers=n_layers,
|
|
@@ -94,6 +97,7 @@ class DDPM(pl.LightningModule):
|
|
| 94 |
model=model,
|
| 95 |
normalization=normalization,
|
| 96 |
centering=inpainting,
|
|
|
|
| 97 |
)
|
| 98 |
edm_class = InpaintingEDM if inpainting else EDM
|
| 99 |
self.edm = edm_class(
|
|
@@ -424,7 +428,7 @@ class DDPM(pl.LightningModule):
|
|
| 424 |
context = fragment_mask
|
| 425 |
|
| 426 |
# Add information about pocket to the context
|
| 427 |
-
if
|
| 428 |
fragment_pocket_mask = fragment_mask
|
| 429 |
fragment_only_mask = template_data['fragment_only_mask']
|
| 430 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
|
|
|
| 44 |
normalize_factors, include_charges, model,
|
| 45 |
data_path, train_data_prefix, val_data_prefix, batch_size, lr, torch_device, test_epochs, n_stability_samples,
|
| 46 |
normalization=None, log_iterations=None, samples_dir=None, data_augmentation=False,
|
| 47 |
+
center_of_mass='fragments', inpainting=False, anchors_context=True, graph_type=None,
|
| 48 |
):
|
| 49 |
super(DDPM, self).__init__()
|
| 50 |
|
|
|
|
| 54 |
self.val_data_prefix = val_data_prefix
|
| 55 |
self.batch_size = batch_size
|
| 56 |
self.lr = lr
|
| 57 |
+
self.torch_device = torch_device
|
| 58 |
self.include_charges = include_charges
|
| 59 |
self.test_epochs = test_epochs
|
| 60 |
self.n_stability_samples = n_stability_samples
|
|
|
|
| 72 |
|
| 73 |
self.is_geom = ('geom' in self.train_data_prefix) or ('MOAD' in self.train_data_prefix)
|
| 74 |
|
| 75 |
+
if graph_type is None:
|
| 76 |
+
graph_type = '4A' if '.' in train_data_prefix else 'FC'
|
| 77 |
+
|
| 78 |
if type(activation) is str:
|
| 79 |
activation = get_activation(activation)
|
| 80 |
|
|
|
|
| 83 |
in_node_nf=in_node_nf,
|
| 84 |
n_dims=n_dims,
|
| 85 |
context_node_nf=context_node_nf,
|
| 86 |
+
device=torch_device,
|
| 87 |
hidden_nf=hidden_nf,
|
| 88 |
activation=activation,
|
| 89 |
n_layers=n_layers,
|
|
|
|
| 97 |
model=model,
|
| 98 |
normalization=normalization,
|
| 99 |
centering=inpainting,
|
| 100 |
+
graph_type=graph_type,
|
| 101 |
)
|
| 102 |
edm_class = InpaintingEDM if inpainting else EDM
|
| 103 |
self.edm = edm_class(
|
|
|
|
| 428 |
context = fragment_mask
|
| 429 |
|
| 430 |
# Add information about pocket to the context
|
| 431 |
+
if '.' in self.train_data_prefix:
|
| 432 |
fragment_pocket_mask = fragment_mask
|
| 433 |
fragment_only_mask = template_data['fragment_only_mask']
|
| 434 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
src/linker_size.py
CHANGED
|
@@ -21,10 +21,6 @@ class DistributionNodes:
|
|
| 21 |
prob = prob/np.sum(prob)
|
| 22 |
|
| 23 |
self.prob = torch.from_numpy(prob).float()
|
| 24 |
-
|
| 25 |
-
entropy = torch.sum(self.prob * torch.log(self.prob + 1e-30))
|
| 26 |
-
print("Entropy of n_nodes: H[N]", entropy.item())
|
| 27 |
-
|
| 28 |
self.m = Categorical(torch.tensor(prob))
|
| 29 |
|
| 30 |
def sample(self, n_samples=1):
|
|
|
|
| 21 |
prob = prob/np.sum(prob)
|
| 22 |
|
| 23 |
self.prob = torch.from_numpy(prob).float()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
self.m = Categorical(torch.tensor(prob))
|
| 25 |
|
| 26 |
def sample(self, n_samples=1):
|
src/linker_size_lightning.py
CHANGED
|
@@ -40,6 +40,7 @@ class SizeClassifier(pl.LightningModule):
|
|
| 40 |
self.lr = lr
|
| 41 |
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
|
|
|
|
| 43 |
self.gnn = SizeGNN(
|
| 44 |
in_node_nf=in_node_nf,
|
| 45 |
hidden_nf=hidden_nf,
|
|
@@ -79,7 +80,7 @@ class SizeClassifier(pl.LightningModule):
|
|
| 79 |
def test_dataloader(self):
|
| 80 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
| 81 |
|
| 82 |
-
def forward(self, data, return_loss=True, with_pocket=False):
|
| 83 |
h = data['one_hot']
|
| 84 |
x = data['positions']
|
| 85 |
fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
|
|
@@ -91,6 +92,10 @@ class SizeClassifier(pl.LightningModule):
|
|
| 91 |
x = x * fragment_mask
|
| 92 |
h = h * fragment_mask
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
# Reshaping
|
| 95 |
bs, n_nodes = x.shape[0], x.shape[1]
|
| 96 |
fragment_mask = fragment_mask.view(bs * n_nodes, 1)
|
|
|
|
| 40 |
self.lr = lr
|
| 41 |
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
|
| 43 |
+
self.in_node_nf = in_node_nf
|
| 44 |
self.gnn = SizeGNN(
|
| 45 |
in_node_nf=in_node_nf,
|
| 46 |
hidden_nf=hidden_nf,
|
|
|
|
| 80 |
def test_dataloader(self):
|
| 81 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
| 82 |
|
| 83 |
+
def forward(self, data, return_loss=True, with_pocket=False, adjust_shape=False):
|
| 84 |
h = data['one_hot']
|
| 85 |
x = data['positions']
|
| 86 |
fragment_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
|
|
|
|
| 92 |
x = x * fragment_mask
|
| 93 |
h = h * fragment_mask
|
| 94 |
|
| 95 |
+
if h.shape[-1] != self.in_node_nf and adjust_shape:
|
| 96 |
+
assert torch.allclose(h[..., -1], torch.zeros_like(h[..., -1]))
|
| 97 |
+
h = h[..., :-1]
|
| 98 |
+
|
| 99 |
# Reshaping
|
| 100 |
bs, n_nodes = x.shape[0], x.shape[1]
|
| 101 |
fragment_mask = fragment_mask.view(bs * n_nodes, 1)
|
src/utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import sys
|
|
|
|
| 2 |
from datetime import datetime
|
| 3 |
|
| 4 |
import torch
|
|
@@ -21,9 +22,11 @@ class Logger(object):
|
|
| 21 |
# you might want to specify some extra behavior here.
|
| 22 |
pass
|
| 23 |
|
|
|
|
| 24 |
def log(*args):
|
| 25 |
print(f'[{datetime.now()}]', *args)
|
| 26 |
|
|
|
|
| 27 |
class EMA:
|
| 28 |
def __init__(self, beta):
|
| 29 |
super().__init__()
|
|
@@ -257,6 +260,17 @@ def disable_rdkit_logging():
|
|
| 257 |
rkrb.DisableLog('rdApp.error')
|
| 258 |
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
class FoundNaNException(Exception):
|
| 261 |
def __init__(self, x, h):
|
| 262 |
x_nan_idx = self.find_nan_idx(x)
|
|
|
|
| 1 |
import sys
|
| 2 |
+
import random
|
| 3 |
from datetime import datetime
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 22 |
# you might want to specify some extra behavior here.
|
| 23 |
pass
|
| 24 |
|
| 25 |
+
|
| 26 |
def log(*args):
    """Print ``args`` to stdout, prefixed with the current local time in square brackets."""
    print(f'[{datetime.now()}]', *args)
|
| 28 |
|
| 29 |
+
|
| 30 |
class EMA:
|
| 31 |
def __init__(self, beta):
|
| 32 |
super().__init__()
|
|
|
|
| 260 |
rkrb.DisableLog('rdApp.error')
|
| 261 |
|
| 262 |
|
| 263 |
+
def set_deterministic(seed):
    """Seed every random number generator used here and pin cuDNN to deterministic mode.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (all CUDA
    devices too, when CUDA is available), then sets the cuDNN backend flags
    ``deterministic = True`` and ``benchmark = False``.

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Both flags are process-wide settings of the cuDNN backend.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 272 |
+
|
| 273 |
+
|
| 274 |
class FoundNaNException(Exception):
|
| 275 |
def __init__(self, x, h):
|
| 276 |
x_nan_idx = self.find_nan_idx(x)
|