Spaces:
Runtime error
Runtime error
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from src import utils | |
| from pdb import set_trace | |
class GCL(nn.Module):
    """Invariant graph-convolution layer operating on node features only.

    Each edge message is an MLP of the concatenated endpoint features (plus
    optional edge attributes), optionally gated by a sigmoid attention weight
    and multiplied by ``edge_mask``. Messages are aggregated onto the ``row``
    endpoint with ``unsorted_segment_sum`` and passed, together with the node's
    own features (and optional node attributes), through a node MLP whose
    output is added residually to the input.

    Args:
        input_nf: width of the incoming node features.
        output_nf: width of the node-MLP output; it is added to ``x`` in
            ``node_model``, so it is expected to match ``input_nf``.
        hidden_nf: hidden width of both MLPs and size of the edge messages.
        normalization_factor: divisor used by ``'sum'`` aggregation.
        aggregation_method: aggregation mode forwarded to ``unsorted_segment_sum``.
        activation: activation module inserted between linear layers.
        edges_in_d: number of extra edge-attribute channels.
        nodes_att_dim: number of extra node-attribute channels.
        attention: if True, scale every message by ``att_mlp(message)``.
        normalization: ``None`` or ``'batch_norm'``; anything else raises
            ``NotImplementedError``.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method, activation,
                 edges_in_d=0, nodes_att_dim=0, attention=False, normalization=None):
        super(GCL, self).__init__()
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.attention = attention

        # Edge MLP input: [source features, target features, edge attributes].
        edge_in = 2 * input_nf + edges_in_d
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in, hidden_nf),
            activation,
            nn.Linear(hidden_nf, hidden_nf),
            activation,
        )

        # Node MLP input: [node features, aggregated messages, node attributes].
        node_in = hidden_nf + input_nf + nodes_att_dim
        if normalization is None:
            node_layers = [
                nn.Linear(node_in, hidden_nf),
                activation,
                nn.Linear(hidden_nf, output_nf),
            ]
        elif normalization == 'batch_norm':
            node_layers = [
                nn.Linear(node_in, hidden_nf),
                nn.BatchNorm1d(hidden_nf),
                activation,
                nn.Linear(hidden_nf, output_nf),
                nn.BatchNorm1d(output_nf),
            ]
        else:
            raise NotImplementedError
        self.node_mlp = nn.Sequential(*node_layers)

        if attention:
            self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        """Compute edge messages.

        Returns:
            ``(out, mij)``: ``mij`` is the raw edge-MLP output; ``out`` is
            ``mij`` after optional attention gating and edge masking.
        """
        pieces = [source, target] if edge_attr is None else [source, target, edge_attr]
        mij = self.edge_mlp(torch.cat(pieces, dim=1))
        out = mij * self.att_mlp(mij) if self.attention else mij
        if edge_mask is not None:
            out = out * edge_mask
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Aggregate messages onto the ``row`` nodes and apply the residual node update.

        Returns:
            ``(x + node_mlp(features), features)`` where ``features`` is the
            concatenated node-MLP input.
        """
        row, _ = edge_index
        messages = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                        normalization_factor=self.normalization_factor,
                                        aggregation_method=self.aggregation_method)
        parts = [x, messages] if node_attr is None else [x, messages, node_attr]
        features = torch.cat(parts, dim=1)
        return x + self.node_mlp(features), features

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        """Run one message-passing step; returns ``(h_new, raw_edge_messages)``."""
        src, dst = edge_index
        messages, mij = self.edge_model(h[src], h[dst], edge_attr, edge_mask)
        h, _ = self.node_model(h, edge_index, messages, node_attr)
        if node_mask is not None:
            h = h * node_mask
        return h, mij
class EquivariantUpdate(nn.Module):
    """Equivariant coordinate update.

    For every edge, ``coord_mlp([h_row, h_col, edge_attr])`` yields one scalar
    that scales ``coord_diff`` (optionally squashed with ``tanh`` and multiplied
    by ``coords_range``). The scaled differences are masked, aggregated onto
    the ``row`` node, optionally masked by ``linker_mask`` and added to ``coord``.
    """

    def __init__(self, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=1, activation=nn.SiLU(), tanh=False, coords_range=10.0):
        super(EquivariantUpdate, self).__init__()
        self.tanh = tanh
        self.coords_range = coords_range

        # Bias-free scalar head with a very small init gain, so the initial
        # per-edge weights (and hence coordinate updates) start out tiny.
        head = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(head.weight, gain=0.001)

        in_features = 2 * hidden_nf + edges_in_d
        self.coord_mlp = nn.Sequential(
            nn.Linear(in_features, hidden_nf),
            activation,
            nn.Linear(hidden_nf, hidden_nf),
            activation,
            head,
        )
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

    def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask, linker_mask):
        """Return ``coord`` plus the aggregated, masked per-edge translations."""
        row, col = edge_index
        edge_weight = self.coord_mlp(torch.cat([h[row], h[col], edge_attr], dim=1))
        if self.tanh:
            shift = coord_diff * torch.tanh(edge_weight) * self.coords_range
        else:
            shift = coord_diff * edge_weight
        if edge_mask is not None:
            shift = shift * edge_mask

        delta = unsorted_segment_sum(shift, row, num_segments=coord.size(0),
                                     normalization_factor=self.normalization_factor,
                                     aggregation_method=self.aggregation_method)
        if linker_mask is not None:
            # Only nodes selected by linker_mask receive a coordinate update.
            delta = delta * linker_mask
        return coord + delta

    def forward(
            self, h, coord, edge_index, coord_diff, edge_attr=None, linker_mask=None, node_mask=None, edge_mask=None
    ):
        """Apply the coordinate update and zero out masked (absent) nodes."""
        updated = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask, linker_mask)
        return updated if node_mask is None else updated * node_mask
class EquivariantBlock(nn.Module):
    """``n_layers`` invariant ``GCL`` layers followed by one ``EquivariantUpdate``.

    Edge attributes fed to every sub-layer are ``[d, edge_attr]``. Here ``d`` is
    the squared distance of the current coordinates (passed through
    ``sin_embedding`` when one is given), and ``edge_attr`` is the caller-provided
    edge feature.
    """

    def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', activation=nn.SiLU(), n_layers=2, attention=True,
                 norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
                 normalization_factor=100, aggregation_method='sum'):
        super(EquivariantBlock, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range)
        self.norm_diff = norm_diff
        self.norm_constant = norm_constant
        self.sin_embedding = sin_embedding
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        for idx in range(n_layers):
            sublayer = GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_feat_nf,
                           activation=activation, attention=attention,
                           normalization_factor=self.normalization_factor,
                           aggregation_method=self.aggregation_method)
            self.add_module(f"gcl_{idx}", sublayer)
        coord_update = EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, activation=activation, tanh=tanh,
                                         coords_range=self.coords_range_layer,
                                         normalization_factor=self.normalization_factor,
                                         aggregation_method=self.aggregation_method)
        self.add_module("gcl_equiv", coord_update)
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, linker_mask=None, edge_mask=None, edge_attr=None):
        """Refine ``h`` with the GCL stack, then update ``x``; returns ``(h, x)``.

        ``edge_attr`` is concatenated to the distance features, so callers must
        pass a tensor (``None`` fails in ``torch.cat``).
        """
        radial, unit_diff = coord2diff(x, edge_index, self.norm_constant)
        if self.sin_embedding is not None:
            radial = self.sin_embedding(radial)
        edge_features = torch.cat([radial, edge_attr], dim=1)

        for idx in range(self.n_layers):
            sublayer = getattr(self, f"gcl_{idx}")
            h, _ = sublayer(h, edge_index, edge_attr=edge_features, node_mask=node_mask, edge_mask=edge_mask)

        x = self.gcl_equiv(
            h, x,
            edge_index=edge_index,
            coord_diff=unit_diff,
            edge_attr=edge_features,
            linker_mask=linker_mask,
            node_mask=node_mask,
            edge_mask=edge_mask,
        )

        # Re-mask: biases in the GCL MLPs can make padded nodes non-zero.
        if node_mask is not None:
            h = h * node_mask
        return h, x
class EGNN(nn.Module):
    """E(n)-equivariant graph network built from ``n_layers`` ``EquivariantBlock``s.

    Node features are linearly embedded to ``hidden_nf``. They are then refined
    jointly with the coordinates by the blocks (each holding ``inv_sublayers``
    GCL layers), and finally projected to ``out_node_nf`` (default:
    ``in_node_nf``). Returns ``(h, x)``.
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', activation=nn.SiLU(), n_layers=3, attention=False,
                 norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
                 sin_embedding=False, normalization_factor=100, aggregation_method='sum'):
        super(EGNN, self).__init__()
        out_node_nf = in_node_nf if out_node_nf is None else out_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        # NOTE(review): stored but not used below -- every block receives the
        # full ``coords_range``, not this per-layer share.
        self.coords_range_layer = float(coords_range/n_layers)
        self.norm_diff = norm_diff
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        # Each block concatenates its own distance features with the initial
        # distance features passed in as ``edge_attr``, hence the factor 2.
        self.sin_embedding = SinusoidsEmbeddingNew() if sin_embedding else None
        edge_feat_nf = 2 if self.sin_embedding is None else 2 * self.sin_embedding.dim

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for idx in range(n_layers):
            block = EquivariantBlock(hidden_nf, edge_feat_nf=edge_feat_nf, device=device,
                                     activation=activation, n_layers=inv_sublayers,
                                     attention=attention, norm_diff=norm_diff, tanh=tanh,
                                     coords_range=coords_range, norm_constant=norm_constant,
                                     sin_embedding=self.sin_embedding,
                                     normalization_factor=self.normalization_factor,
                                     aggregation_method=self.aggregation_method)
            self.add_module(f"e_block_{idx}", block)
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, linker_mask=None, edge_mask=None):
        """Embed ``h``, run all equivariant blocks and project back; returns ``(h, x)``."""
        # Squared distances of the *input* coordinates, shared by all blocks.
        initial_radial, _ = coord2diff(x, edge_index)
        if self.sin_embedding is not None:
            initial_radial = self.sin_embedding(initial_radial)

        h = self.embedding(h)
        for idx in range(self.n_layers):
            block = getattr(self, f"e_block_{idx}")
            h, x = block(
                h, x, edge_index,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=edge_mask,
                edge_attr=initial_radial,
            )

        h = self.embedding_out(h)
        # The output projection has a bias, so padded nodes must be re-masked.
        if node_mask is not None:
            h = h * node_mask
        return h, x
class GNN(nn.Module):
    """Non-equivariant message-passing network.

    Structure: linear embedding → ``n_layers`` ``GCL`` layers → linear output
    projection to ``out_node_nf`` (default: ``in_node_nf``).
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, aggregation_method='sum', device='cpu',
                 activation=nn.SiLU(), n_layers=4, attention=False, normalization_factor=1,
                 out_node_nf=None, normalization=None):
        super(GNN, self).__init__()
        out_node_nf = in_node_nf if out_node_nf is None else out_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for idx in range(n_layers):
            sublayer = GCL(
                self.hidden_nf, self.hidden_nf, self.hidden_nf,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                edges_in_d=in_edge_nf, activation=activation,
                attention=attention, normalization=normalization)
            self.add_module(f"gcl_{idx}", sublayer)
        self.to(self.device)

    def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
        """Return the updated node features of width ``out_node_nf``."""
        h = self.embedding(h)
        for idx in range(self.n_layers):
            h, _ = getattr(self, f"gcl_{idx}")(h, edges, edge_attr=edge_attr,
                                                node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)
        # The output projection has a bias, so padded nodes must be re-masked.
        return h if node_mask is None else h * node_mask
class SinusoidsEmbeddingNew(nn.Module):
    """Fixed (non-learned) sin/cos embedding of squared distances.

    The input is square-rooted and multiplied by ``n_frequencies`` angular
    frequencies ``2*pi*div_factor**k / max_res``. ``n_frequencies`` is chosen
    so that the frequencies cover the range from ``max_res`` down to about
    ``min_res``. Output width is ``dim = 2 * n_frequencies`` (sines, then
    cosines), and the result is detached from the autograd graph.
    """

    def __init__(self, max_res=15., min_res=15. / 2000., div_factor=4):
        super().__init__()
        ratio = max_res / min_res
        self.n_frequencies = int(math.log(ratio, div_factor)) + 1
        powers = div_factor ** torch.arange(self.n_frequencies)
        # Plain attribute (not a buffer): moved to the input's device in forward.
        self.frequencies = 2 * math.pi * powers / max_res
        self.dim = 2 * self.n_frequencies

    def forward(self, x):
        """Embed ``x`` (last dim of size 1, e.g. ``(E, 1)``) into ``(..., dim)``."""
        freqs = self.frequencies.to(x.device).unsqueeze(0)
        phases = torch.sqrt(x + 1e-8) * freqs
        return torch.cat([phases.sin(), phases.cos()], dim=-1).detach()
def coord2diff(x, edge_index, norm_constant=1):
    """Per-edge squared distance and (softly) normalised coordinate difference.

    Args:
        x: ``(num_nodes, dims)`` coordinates.
        edge_index: pair ``(row, col)`` of index tensors.
        norm_constant: added to the edge length before dividing, which keeps the
            division stable for (near-)coincident nodes.

    Returns:
        ``(radial, coord_diff)``: ``radial`` is ``(E, 1)`` with
        ``|x_row - x_col|^2``; ``coord_diff`` is
        ``(x_row - x_col) / (sqrt(radial + 1e-8) + norm_constant)``.
    """
    src, dst = edge_index
    delta = x[src] - x[dst]
    radial = (delta ** 2).sum(dim=1, keepdim=True)
    length = torch.sqrt(radial + 1e-8)
    return radial, delta / (length + norm_constant)
def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.

    Row ``e`` of ``data`` is added into output row ``segment_ids[e]``.

    Args:
        data: 2-D tensor ``(E, F)``.
        segment_ids: 1-D integer tensor of length ``E`` with values in
            ``[0, num_segments)``.
        num_segments: number of output rows.
        normalization_factor: divisor applied to every output row when
            ``aggregation_method == 'sum'``.
        aggregation_method: ``'sum'`` (divide by ``normalization_factor``) or
            ``'mean'`` (divide by the number of rows scattered into each
            segment; empty segments stay zero).

    Returns:
        Tensor ``(num_segments, F)`` with the dtype/device of ``data``.

    Raises:
        ValueError: if ``aggregation_method`` is neither ``'sum'`` nor ``'mean'``
            (previously such values silently returned un-normalised sums).
    """
    if aggregation_method not in ('sum', 'mean'):
        raise ValueError(
            f"Unknown aggregation_method {aggregation_method!r}; expected 'sum' or 'mean'")

    index = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result = data.new_zeros((num_segments, data.size(1)))
    result.scatter_add_(0, index, data)
    if aggregation_method == 'sum':
        return result / normalization_factor

    # 'mean': count contributions per segment; avoid 0/0 for empty segments.
    counts = data.new_zeros(result.shape)
    counts.scatter_add_(0, index, data.new_ones(data.shape))
    counts[counts == 0] = 1
    return result / counts
class Dynamics(nn.Module):
    """Dynamics network on fully-connected graphs: predicts ``[vel, h]`` from ``(t, xh)``.

    Wraps either an ``EGNN`` (``model='egnn_dynamics'``) or a plain ``GNN``
    (``model='gnn_dynamics'``). The dense batch ``xh`` of shape
    ``(B, N, n_dims + nf)`` is processed as follows:

    1. It is flattened to ``(B*N, ...)``.
    2. A time channel (``condition_time``) and per-node context are appended.
    3. The result is run through the network.
    4. Time/context channels are stripped from the output, which is returned as
       ``(B, N, n_dims + nf)``.
    """

    def __init__(
            self, n_dims, in_node_nf, context_node_nf, hidden_nf=64, device='cpu', activation=nn.SiLU(),
            n_layers=4, attention=False, condition_time=True, tanh=False, norm_constant=0, inv_sublayers=2,
            sin_embedding=False, normalization_factor=100, aggregation_method='sum', model='egnn_dynamics',
            normalization=None, centering=False, graph_type='FC',
    ):
        super().__init__()
        self.device = device
        self.n_dims = n_dims
        self.context_node_nf = context_node_nf
        self.condition_time = condition_time
        self.model = model
        self.centering = centering
        self.graph_type = graph_type

        # Network input width: node features + context + one time channel
        # (``condition_time`` is a bool and counts as 0/1).
        in_node_nf = in_node_nf + context_node_nf + condition_time
        if self.model == 'egnn_dynamics':
            self.dynamics = EGNN(
                in_node_nf=in_node_nf,
                in_edge_nf=1,
                hidden_nf=hidden_nf, device=device,
                activation=activation,
                n_layers=n_layers,
                attention=attention,
                tanh=tanh,
                norm_constant=norm_constant,
                inv_sublayers=inv_sublayers,
                sin_embedding=sin_embedding,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
            )
        elif self.model == 'gnn_dynamics':
            # The GNN sees coordinates as ordinary features (see forward), so
            # its input/output width grows by the number of spatial dims.
            self.dynamics = GNN(
                in_node_nf=in_node_nf + n_dims,
                in_edge_nf=0,
                hidden_nf=hidden_nf,
                out_node_nf=in_node_nf + n_dims,
                device=device,
                activation=activation,
                n_layers=n_layers,
                attention=attention,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                normalization=normalization,
            )
        else:
            raise NotImplementedError
        # Maps n_nodes -> {batch_size: [rows, cols]}; filled lazily by get_edges.
        self.edge_cache = {}

    def forward(self, t, xh, node_mask, linker_mask, edge_mask, context):
        """Predict ``[vel, h]`` for a batch of fully-connected graphs.

        Args:
            t: time; a single value shared by the whole batch, or one value per
                sample (``B`` values).
            xh: ``(B, N, D)`` where ``D = n_dims + nf`` (coordinates first).
            node_mask: ``(B, N, 1)``.
            linker_mask: optional, reshaped to ``(B*N, 1)``; forwarded to the
                EGNN coordinate update.
            edge_mask: ``(B*N*N, 1)``, aligned with ``get_edges`` ordering.
            context: optional ``(B, N, context_node_nf)``.

        Returns:
            ``(B, N, D)`` tensor ``[vel, h_final]``.
        """
        assert self.graph_type == 'FC'
        bs, n_nodes = xh.shape[0], xh.shape[1]
        edges = self.get_edges(n_nodes, bs)  # [rows, cols], each of length B*N*N
        node_mask = node_mask.view(bs * n_nodes, 1)  # (B*N, 1)
        if linker_mask is not None:
            linker_mask = linker_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        # Flatten the batch and split coordinates from features.
        xh = xh.view(bs * n_nodes, -1).clone() * node_mask  # (B*N, D)
        x = xh[:, :self.n_dims].clone()  # (B*N, n_dims)
        h = xh[:, self.n_dims:].clone()  # (B*N, nf)

        if self.condition_time:
            if t.numel() == 1:
                # t is the same for all elements in the batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t differs over the batch dimension.
                h_time = t.view(bs, 1).repeat(1, n_nodes).view(bs * n_nodes, 1)
            h = torch.cat([h, h_time], dim=1)  # (B*N, nf+1)
        if context is not None:
            context = context.view(bs * n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        if self.model == 'egnn_dynamics':
            h_final, x_final = self.dynamics(
                h,
                x,
                edges,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=edge_mask
            )
            vel = (x_final - x) * node_mask  # Masking is redundant but kept as a safeguard.
        elif self.model == 'gnn_dynamics':
            xh = torch.cat([x, h], dim=1)
            output = self.dynamics(xh, edges, node_mask=node_mask)
            vel = output[:, :self.n_dims] * node_mask
            h_final = output[:, self.n_dims:]
        else:
            raise NotImplementedError

        # Strip the appended channels in reverse order: context (last), then time.
        # Explicit width instead of ``[:, :-C]``, which would empty the tensor for C == 0.
        if context is not None:
            h_final = h_final[:, :h_final.size(1) - self.context_node_nf]
        if self.condition_time:
            h_final = h_final[:, :-1]

        vel = vel.view(bs, n_nodes, -1)  # (B, N, n_dims)
        h_final = h_final.view(bs, n_nodes, -1)  # (B, N, nf)
        node_mask = node_mask.view(bs, n_nodes, 1)  # (B, N, 1)

        if self.centering:
            vel = utils.remove_mean_with_mask(vel, node_mask)

        return torch.cat([vel, h_final], dim=2)

    def get_edges(self, n_nodes, batch_size):
        """Return the (cached) fully-connected edge index for a batch of graphs.

        Every graph has ``n_nodes`` nodes; node ``i`` of graph ``b`` has global
        index ``b * n_nodes + i``. Edges include self-loops and never cross
        graphs. They are ordered by graph, then source node, then target node.

        Returns:
            List ``[rows, cols]`` of int64 tensors of length
            ``batch_size * n_nodes ** 2`` on ``self.device``. The same list
            object is returned on later calls with the same arguments.
        """
        by_batch = self.edge_cache.setdefault(n_nodes, {})
        if batch_size not in by_batch:
            offsets = torch.arange(batch_size, dtype=torch.long) * n_nodes  # (B,)
            local = torch.arange(n_nodes, dtype=torch.long)  # (N,)
            grid = (batch_size, n_nodes, n_nodes)
            rows = (offsets[:, None, None] + local[None, :, None]).expand(grid).reshape(-1)
            cols = (offsets[:, None, None] + local[None, None, :]).expand(grid).reshape(-1)
            by_batch[batch_size] = [rows.to(self.device), cols.to(self.device)]
        return by_batch[batch_size]
class DynamicsWithPockets(Dynamics):
    """``Dynamics`` variant for ligand + pocket systems on distance-based sparse graphs.

    The last two context channels are read as fragment-only and pocket-only
    masks. Together with ``linker_mask`` they must cover exactly the present
    nodes. Edges are built per call from the current coordinates according to
    ``graph_type`` (``'4A'``, ``'FC-4A'`` or ``'FC-10A-4A'``).
    """

    def forward(self, t, xh, node_mask, linker_mask, edge_mask, context):
        """Predict ``[vel, h]`` for a batch of ligand/pocket graphs.

        Args:
            t: time; a single value shared by the batch, or one value per sample.
            xh: ``(B, N, D)`` where ``D = n_dims + nf`` (coordinates first).
            node_mask: ``(B, N, 1)``.
            linker_mask: ``(B, N, 1)``-compatible mask; required here (it is
                used in the partition check).
            edge_mask: passed to the edge builders as ``batch_mask``. It is
                compared pairwise against the ``(B*N, B*N)`` adjacency, so it
                must be 1-D of length ``B*N`` -- presumably each node's sample
                index (TODO confirm with caller).
            context: ``(B, N, context_node_nf)``; the last two channels are the
                fragment-only and pocket-only masks.

        Returns:
            ``(B, N, D)`` tensor ``[vel, h_final]``.

        Raises:
            utils.FoundNaNException: if the prediction contains NaNs.
        """
        bs, n_nodes = xh.shape[0], xh.shape[1]
        node_mask = node_mask.view(bs * n_nodes, 1)  # (B*N, 1)
        if linker_mask is not None:
            linker_mask = linker_mask.view(bs * n_nodes, 1)  # (B*N, 1)

        fragment_only_mask = context[..., -2].view(bs * n_nodes, 1)  # (B*N, 1)
        pocket_only_mask = context[..., -1].view(bs * n_nodes, 1)  # (B*N, 1)
        # Every present atom is a fragment, pocket or linker atom; absent atoms are none of them.
        assert torch.all(
            (fragment_only_mask.bool() | pocket_only_mask.bool() | linker_mask.bool()) == node_mask.bool()
        )

        # Reshaping node features & splitting coordinates from features.
        xh = xh.view(bs * n_nodes, -1).clone() * node_mask  # (B*N, D)
        x = xh[:, :self.n_dims].clone()  # (B*N, n_dims)
        h = xh[:, self.n_dims:].clone()  # (B*N, nf)

        assert self.graph_type in ['4A', 'FC-4A', 'FC-10A-4A']
        if self.graph_type == '4A' or self.graph_type is None:
            edges = self.get_dist_edges_4A(x, node_mask, edge_mask)
        else:
            edges = self.get_dist_edges(x, node_mask, edge_mask, linker_mask, fragment_only_mask, pocket_only_mask)

        if self.condition_time:
            if t.numel() == 1:
                # t is the same for all elements in the batch.
                h_time = torch.empty_like(h[:, 0:1]).fill_(t.item())
            else:
                # t differs over the batch dimension.
                h_time = t.view(bs, 1).repeat(1, n_nodes).view(bs * n_nodes, 1)
            h = torch.cat([h, h_time], dim=1)  # (B*N, nf+1)
        if context is not None:
            context = context.view(bs * n_nodes, self.context_node_nf)
            h = torch.cat([h, context], dim=1)

        if self.model == 'egnn_dynamics':
            # Edges are already restricted to valid pairs, so no edge mask is passed.
            h_final, x_final = self.dynamics(
                h,
                x,
                edges,
                node_mask=node_mask,
                linker_mask=linker_mask,
                edge_mask=None
            )
            vel = (x_final - x) * node_mask  # Masking is redundant but kept as a safeguard.
        elif self.model == 'gnn_dynamics':
            xh = torch.cat([x, h], dim=1)
            output = self.dynamics(xh, edges, node_mask=node_mask)
            vel = output[:, :self.n_dims] * node_mask
            h_final = output[:, self.n_dims:]
        else:
            raise NotImplementedError

        # Strip the appended channels in reverse order: context (last), then time.
        # Explicit width instead of ``[:, :-C]``, which would empty the tensor for C == 0.
        if context is not None:
            h_final = h_final[:, :h_final.size(1) - self.context_node_nf]
        if self.condition_time:
            h_final = h_final[:, :-1]

        vel = vel.view(bs, n_nodes, -1)  # (B, N, n_dims)
        h_final = h_final.view(bs, n_nodes, -1)  # (B, N, nf)
        node_mask = node_mask.view(bs, n_nodes, 1)  # (B, N, 1)

        if torch.any(torch.isnan(vel)) or torch.any(torch.isnan(h_final)):
            raise utils.FoundNaNException(vel, h_final)

        if self.centering:
            vel = utils.remove_mean_with_mask(vel, node_mask)

        return torch.cat([vel, h_final], dim=2)

    @staticmethod
    def get_dist_edges_4A(x, node_mask, batch_mask):
        """Edges between distinct present nodes of the same sample within distance 4.

        Declared ``@staticmethod`` because ``forward`` calls it as
        ``self.get_dist_edges_4A(x, node_mask, batch_mask)``. Without the
        decorator, ``self`` would be bound to ``x`` and the call would raise
        ``TypeError``.

        Args:
            x: ``(B*N, n_dims)`` coordinates (distance units presumably
                Ångström, per the ``'4A'`` name).
            node_mask: ``(B*N, 1)`` presence mask.
            batch_mask: 1-D ``(B*N,)`` per-node sample identifier.

        Returns:
            ``(2, E)`` int64 tensor of ``(row, col)`` indices.
        """
        node_mask = node_mask.reshape(-1).bool()
        same_sample = batch_mask[:, None] == batch_mask[None, :]
        both_present = node_mask[:, None] & node_mask[None, :]
        close = torch.cdist(x, x) <= 4
        not_self = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
        return torch.stack(torch.where(same_sample & both_present & close & not_self))

    def get_dist_edges(self, x, node_mask, batch_mask, linker_mask, fragment_only_mask, pocket_only_mask):
        """Mixed-cutoff edges for ligand/pocket graphs.

        All edges connect distinct present nodes of the same sample. The three
        edge types are:

        - ligand-ligand (linker or fragment atoms): fully connected;
        - pocket-pocket: within distance 4;
        - ligand-pocket (both directions): within 4 for ``graph_type == 'FC-4A'``,
          otherwise within 10.

        Returns:
            ``(2, E)`` int64 tensor of ``(row, col)`` indices.
        """
        node_mask = node_mask.reshape(-1).bool()
        linker_mask = linker_mask.reshape(-1).bool() & node_mask
        fragment_only_mask = fragment_only_mask.reshape(-1).bool() & node_mask
        pocket_only_mask = pocket_only_mask.reshape(-1).bool() & node_mask
        ligand_mask = linker_mask | fragment_only_mask

        # General constraints: same sample, both nodes present, no self-loops.
        batch_adj = batch_mask[:, None] == batch_mask[None, :]
        nodes_adj = node_mask[:, None] & node_mask[None, :]
        rm_self_loops = ~torch.eye(x.size(0), dtype=torch.bool, device=x.device)
        constraints = batch_adj & nodes_adj & rm_self_loops

        # Pairwise distances, computed once and reused for both cutoffs.
        dists = torch.cdist(x, x)

        # Ligand atoms: fully-connected graph.
        ligand_interactions = ligand_mask[:, None] & ligand_mask[None, :] & constraints

        # Pocket atoms: within 4.
        pocket_adj = pocket_only_mask[:, None] & pocket_only_mask[None, :]
        pocket_interactions = pocket_adj & (dists <= 4) & constraints

        # Pocket-ligand atoms (both directions): within 4 for 'FC-4A', else within 10.
        pocket_ligand_cutoff = 4 if self.graph_type == 'FC-4A' else 10
        pocket_ligand_adj = (
            (ligand_mask[:, None] & pocket_only_mask[None, :])
            | (pocket_only_mask[:, None] & ligand_mask[None, :])
        )
        pocket_ligand_interactions = pocket_ligand_adj & (dists <= pocket_ligand_cutoff) & constraints

        adj = ligand_interactions | pocket_interactions | pocket_ligand_interactions
        return torch.stack(torch.where(adj))