New data preprocessing #36

Open · wants to merge 11 commits into base: master
5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
"githubPullRequests.ignoredPullRequestBranches": [
"master"
]
}
90 changes: 37 additions & 53 deletions README.md
@@ -1,27 +1,14 @@
# TGN: Temporal Graph Networks [[arXiv](https://arxiv.org/abs/2006.10637), [YouTube](https://www.youtube.com/watch?v=W1GvX2ZcUmY), [Blog Post](https://towardsdatascience.com/temporal-graph-networks-ab8f327f2efe)]
# Experiments with Temporal Graph Networks

Dynamic Graph | TGN
:-------------------------:|:-------------------------:
![](figures/dynamic_graph.png) | ![](figures/tgn.png)
## Introduction

Starting from the paper [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637) and the repository https://github.com/twitterresearch/tgn, this project is an experimental evaluation of the Temporal Graph
Networks defined in that paper, exploring the future lines of work it proposes and assessing
the results obtained.

## Running the experiments


## Introduction

Despite the plethora of different models for deep learning on graphs, few approaches have been proposed thus far for dealing with graphs that present some sort of dynamic nature (e.g. evolving features or connectivity over time).

In this paper, we present Temporal Graph Networks (TGNs), a generic, efficient framework for deep learning on dynamic graphs represented as sequences of timed events. Thanks to a novel combination of memory modules and graph-based operators, TGNs are able to significantly outperform previous approaches while at the same time being more computationally efficient.

We furthermore show that several previous models for learning on dynamic graphs can be cast as specific instances of our framework. We perform a detailed ablation study of different components of our framework and devise the best configuration that achieves state-of-the-art performance on several transductive and inductive prediction tasks for dynamic graphs.


#### Paper link: [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637)


## Running the experiments

### Requirements

Dependencies (with python >= 3.7):

@@ -31,81 +18,78 @@
torch==1.6.0
scikit_learn==0.23.1
```

### Datasets and preprocessing

#### TGN datasets

#### Download the public data
Download the sample datasets (e.g. wikipedia and reddit) from
[here](http://snap.stanford.edu/jodie/) and store their csv files in a folder named
```data/```.
##### Download the data
The wikipedia and reddit datasets can be downloaded from [here](http://snap.stanford.edu/jodie/) and should be stored in the folders
```data/tgn_wikipedia``` and ```data/tgn_reddit``` respectively.
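
For reference, a minimal download sketch in Python follows. The direct CSV URLs and the destination layout are assumptions inferred from the JODIE page and the folder names above, not the repo's official download path.

```{python}
# Download sketch -- the CSV URLs are assumed from the JODIE dataset page.
import os
import urllib.request

DATASETS = {
    "wikipedia": "http://snap.stanford.edu/jodie/wikipedia.csv",
    "reddit": "http://snap.stanford.edu/jodie/reddit.csv",
}

for name, url in DATASETS.items():
    folder = os.path.join("data", f"tgn_{name}")
    os.makedirs(folder, exist_ok=True)
    dest = os.path.join(folder, f"{name}.csv")
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)
        print(f"downloaded {name} -> {dest}")
```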

#### Preprocess the data
We use the dense `npy` format to save the features in binary format. If edge features or node
features are absent, they will be replaced by a vector of zeros.
#### Preprocessing the data
`.npy` files are used to store the generated data. If node or edge features are absent, they are filled with zeros.
```{bash}
python utils/preprocess_data.py --data wikipedia --bipartite
python utils/preprocess_data.py --data reddit --bipartite
python3 utils/tgn_preprocess_data.py --data wikipedia --bipartite
python3 utils/tgn_preprocess_data.py --data reddit --bipartite
```
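
An illustration of the zero-filling convention follows; the shapes, padding row, and filename are assumptions, since the preprocessing script derives the real values from the CSV.

```{python}
# Zero-filling sketch -- shapes and filename are illustrative assumptions.
import numpy as np

n_nodes, feat_dim = 9227, 172  # wikipedia-sized, for illustration only

# When a dataset ships no node features, an all-zeros matrix is stored instead.
node_features = np.zeros((n_nodes + 1, feat_dim))  # +1 row for padding (assumed)
np.save("data/tgn_wikipedia/node_features.npy", node_features)
```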

### Model Training

Self-supervised learning using the link prediction task:
For link prediction:
```{bash}
# TGN-attn: Self-supervised learning on the wikipedia dataset
python train_self_supervised.py --use_memory --prefix tgn-attn --n_runs 10
python3 tgn_link_prediction.py --use_memory --prefix tgn-attn --n_runs 10

# TGN-attn-reddit: Self-supervised learning on the reddit dataset
python train_self_supervised.py -d reddit --use_memory --prefix tgn-attn-reddit --n_runs 10
python3 tgn_link_prediction.py -d reddit --use_memory --prefix tgn-attn-reddit --n_runs 10
```
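
For context, here is a sketch of what one training step of this task optimizes, built around the `compute_edge_probabilities` method in `model/tgn.py`; the helper name, batch layout, and optimizer handling are hypothetical.

```{python}
# Hypothetical training-step sketch for the link prediction task.
import torch

def link_prediction_step(tgn, batch, optimizer, n_neighbors=20):
    """One self-supervised step: score true edges against sampled negatives.

    `tgn` is a model.tgn.TGN instance; `batch` is a tuple of numpy arrays
    (sources, destinations, negatives, timestamps, edge_idxs).
    """
    sources, destinations, negatives, timestamps, edge_idxs = batch
    criterion = torch.nn.BCELoss()

    optimizer.zero_grad()
    pos_prob, neg_prob = tgn.compute_edge_probabilities(
        sources, destinations, negatives, timestamps, edge_idxs,
        n_neighbors=n_neighbors)

    # True edges are labeled 1, negatively sampled pairs 0.
    loss = criterion(pos_prob.squeeze(), torch.ones_like(pos_prob.squeeze()))
    loss += criterion(neg_prob.squeeze(), torch.zeros_like(neg_prob.squeeze()))
    loss.backward()
    optimizer.step()
    return loss.item()
```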

Supervised learning on dynamic node classification (this requires a trained model from
the self-supervised task, e.g. by running the commands above):
For node classification (this requires the model trained on the link prediction task):
```{bash}
# TGN-attn: supervised learning on the wikipedia dataset
python train_supervised.py --use_memory --prefix tgn-attn --n_runs 10
python3 tgn_node_classification.py --use_memory --prefix tgn-attn --n_runs 10

# TGN-attn-reddit: supervised learning on the reddit dataset
python train_supervised.py -d reddit --use_memory --prefix tgn-attn-reddit --n_runs 10
python3 tgn_node_classification.py -d reddit --use_memory --prefix tgn-attn-reddit --n_runs 10
```

### Baselines
### JODIE and DyRep

```{bash}
### Wikipedia Self-supervised
### Link prediction on Wikipedia

# Jodie
python train_self_supervised.py --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn --n_runs 10
python3 tgn_link_prediction.py --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn --n_runs 10

# DyRep
python train_self_supervised.py --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn --n_runs 10
python3 tgn_link_prediction.py --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn --n_runs 10


### Reddit Self-supervised
### Link prediction on Reddit

# Jodie
python train_self_supervised.py -d reddit --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn_reddit --n_runs 10
python3 tgn_link_prediction.py -d reddit --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn_reddit --n_runs 10

# DyRep
python train_self_supervised.py -d reddit --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn_reddit --n_runs 10
python3 tgn_link_prediction.py -d reddit --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn_reddit --n_runs 10


### Wikipedia Supervised
### Node classification on Wikipedia

# Jodie
python train_supervised.py --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn --n_runs 10
python3 tgn_node_classification.py --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn --n_runs 10

# DyRep
python train_supervised.py --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn --n_runs 10
python3 tgn_node_classification.py --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn --n_runs 10


### Reddit Supervised
### Node classification on Reddit

# Jodie
python train_supervised.py -d reddit --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn_reddit --n_runs 10
python3 tgn_node_classification.py -d reddit --use_memory --memory_updater rnn --embedding_module time --prefix jodie_rnn_reddit --n_runs 10

# DyRep
python train_supervised.py -d reddit --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn_reddit --n_runs 10
python3 tgn_node_classification.py -d reddit --use_memory --memory_updater rnn --dyrep --use_destination_embedding_in_message --prefix dyrep_rnn_reddit --n_runs 10
```
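
These flags instantiate JODIE and DyRep as special cases of the TGN framework. As a quick reference, the mapping implied by the commands above can be summarized in code form (a summary, not a module from the repo):

```{python}
# Baselines expressed as TGN configurations, compiled from the commands above.
BASELINE_FLAGS = {
    "jodie": {
        "use_memory": True,
        "memory_updater": "rnn",
        "embedding_module": "time",  # time-projection embedding
    },
    "dyrep": {
        "use_memory": True,
        "memory_updater": "rnn",
        "dyrep": True,
        "use_destination_embedding_in_message": True,
    },
}
```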


104 changes: 64 additions & 40 deletions model/tgn.py
@@ -10,20 +10,23 @@
from modules.memory_updater import get_memory_updater
from modules.embedding_module import get_embedding_module
from model.time_encoding import TimeEncode
from modules.feature_embedding import get_feature_embedding


class TGN(torch.nn.Module):
def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
n_heads=2, dropout=0.1, use_memory=False,
memory_update_at_start=True, message_dimension=100,
memory_dimension=500, embedding_module_type="graph_attention",
message_function="mlp",
message_function="identity",
mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
memory_updater_type="gru",
use_destination_embedding_in_message=False,
use_source_embedding_in_message=False,
dyrep=False):
dyrep=False,
feature_embedding_type="identity",
feature_dimension=50):
super(TGN, self).__init__()

self.n_layers = n_layers
@@ -43,6 +46,8 @@ def __init__(self, neighbor_finder, node_features, edge_features, device, n_laye
self.use_destination_embedding_in_message = use_destination_embedding_in_message
self.use_source_embedding_in_message = use_source_embedding_in_message
self.dyrep = dyrep
self.feature_embedding_type = feature_embedding_type
self.feature_dimension = feature_dimension

self.use_memory = use_memory
self.time_encoder = TimeEncode(dimension=self.n_node_features)
@@ -65,7 +70,7 @@ def __init__(self, neighbor_finder, node_features, edge_features, device, n_laye
message_dimension=message_dimension,
device=device)
self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
device=device)
device=device, raw_message_dimension=raw_message_dimension)
self.message_function = get_message_function(module_type=message_function,
raw_message_dimension=raw_message_dimension,
message_dimension=message_dimension)
@@ -74,6 +79,9 @@ def __init__(self, neighbor_finder, node_features, edge_features, device, n_laye
message_dimension=message_dimension,
memory_dimension=self.memory_dimension,
device=device)
# New in this PR: feature_embedding maps the raw edge features (n_edge_features)
# to feature_dimension before they enter the raw messages (see get_raw_messages);
# with the default "identity" type it presumably passes them through unchanged.
self.feature_embedding = get_feature_embedding(module_type=self.feature_embedding_type,
raw_features_dimension=self.n_edge_features,
features_dimension=self.feature_dimension)

self.embedding_module_type = embedding_module_type

@@ -124,25 +132,22 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_
if self.memory_update_at_start:
# Update memory for all nodes with messages stored in previous batches
memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
self.memory.messages)
self.memory.messages)
else:
memory = self.memory.get_memory(list(range(self.n_nodes)))
last_update = self.memory.last_update

### Compute differences between the time the memory of a node was last updated,
### and the time for which we want to compute the embedding of a node
source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
source_nodes].long()
source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
destination_nodes].long()
destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[
negative_nodes].long()
negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
dim=0)
### Compute differences between the time the memory of a node was last updated,
### and the time for which we want to compute the embedding of a node
source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[source_nodes].long()
source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src
destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[destination_nodes].long()
destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst
negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[negative_nodes].long()
negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst

time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs],
dim=0)

# Compute the embeddings using the embedding module
node_embedding = self.embedding_module.compute_embedding(memory=memory,
@@ -162,22 +167,22 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_
# new messages for them)
self.update_memory(positives, self.memory.messages)

assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
"Something wrong in how the memory was updated"
# assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
# "Something wrong in how the memory was updated"

# Remove messages for the positives since we have already updated the memory using them
self.memory.clear_messages(positives)

unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
source_node_embedding,
destination_nodes,
destination_node_embedding,
edge_times, edge_idxs)
source_node_embedding,
destination_nodes,
destination_node_embedding,
edge_times, edge_idxs)
unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
destination_node_embedding,
source_nodes,
source_node_embedding,
edge_times, edge_idxs)
destination_node_embedding,
source_nodes,
source_node_embedding,
edge_times, edge_idxs)
if self.memory_update_at_start:
self.memory.store_raw_messages(unique_sources, source_id_to_messages)
self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
@@ -191,6 +196,7 @@ def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_
negative_node_embedding = memory[negative_nodes]

return source_node_embedding, destination_node_embedding, negative_node_embedding


def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
edge_idxs, n_neighbors=20):
@@ -219,40 +225,58 @@ def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_n
return pos_score.sigmoid(), neg_score.sigmoid()

def update_memory(self, nodes, messages):
"""
1. Agrega los mensajes pertenecientes a los mismos nodos.
2. Calcula los mensajes únicos de cada nodo.
3. Actualiza la memoria con los mensajes agregados y calculados.
"""
# Aggregate messages for the same nodes
unique_nodes, unique_messages, unique_timestamps = \
self.message_aggregator.aggregate(
nodes,
messages)

if len(unique_nodes) > 0:
unique_nodes, unique_messages, unique_timestamps = self.message_aggregator.aggregate(nodes, messages)

# Compute messages from raw messages
if len(unique_messages) > 0:
unique_messages = self.message_function.compute_message(unique_messages)

# Update the memory with the aggregated messages
self.memory_updater.update_memory(unique_nodes, unique_messages,
timestamps=unique_timestamps)

def update_memory_new(self, positives, updated_memory, last_update):
# Placeholder (new in this PR): updating the memory with a precomputed
# tensor is currently disabled.
# self.memory_updater.update_memory_new(unique_nodes, updated_memory, last_update)
return

def get_updated_memory(self, nodes, messages):
"""
1. Agrega los mensajes pertenecientes a los mismos nodos.
2. Calcula los mensajes únicos de cada nodo.
3. Actualiza la memoria con los mensajes agregados y calculados.
4. Devuelve la memoria actualizada y los tiempos de la última actualización.
"""

# Aggregate messages for the same nodes
unique_nodes, unique_messages, unique_timestamps = \
self.message_aggregator.aggregate(
nodes,
messages)
unique_nodes, unique_messages, unique_timestamps = self.message_aggregator.aggregate(nodes, messages)

print(len(unique_nodes))  # debug print left in by this PR: nodes with pending messages

if len(unique_nodes) > 0:
unique_messages = self.message_function.compute_message(unique_messages)

updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes,
unique_messages,
timestamps=unique_timestamps)

return updated_memory, updated_last_update

def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
destination_node_embedding, edge_times, edge_idxs):
edge_times = torch.from_numpy(edge_times).float().to(self.device)
edge_features = self.edge_raw_features[edge_idxs]

# Learned embedding of the edge features
edge_features = self.feature_embedding.compute_features(edge_features)

source_memory = self.memory.get_memory(source_nodes) if not \
self.use_source_embedding_in_message else source_node_embedding
destination_memory = self.memory.get_memory(destination_nodes) if \
Expand All @@ -264,7 +288,7 @@ def get_raw_messages(self, source_nodes, source_node_embedding, destination_node

source_message = torch.cat([source_memory, destination_memory, edge_features,
source_time_delta_encoding],
dim=1)
dim=1)
messages = defaultdict(list)
unique_sources = np.unique(source_nodes)
