Skip to content

Commit

Permalink
updated dragonn workshop tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
jisraeli committed Aug 24, 2016
1 parent e5f5ff8 commit cf033e3
Show file tree
Hide file tree
Showing 2 changed files with 1,210 additions and 239 deletions.
27 changes: 12 additions & 15 deletions dragonn/tutorial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,21 +185,16 @@ def plot_SequenceDNN_layer_outputs(dnn, simulation_data):
ax2.axvspan(conv_output_start, conv_output_stop, color='grey', alpha=0.5)


def interpret_SequenceDNN_distributed(dnn, simulation_data, plot_layer_outputs=False):
def interpret_SequenceDNN_filters(dnn, simulation_data):
print("Plotting simulation motifs...")
plot_motifs(simulation_data)
plt.show()
print("Visualizing convolutional sequence filters in SequenceDNN...")
plot_sequence_filters(dnn)
plt.show()
if plot_layer_outputs:
print("%s %s" % ("Plotting outputs of convolutional and max pooling layer",
"for a positive simulation example..."))
plot_SequenceDNN_layer_outputs(dnn, simulation_data)
plt.show()


def interpret_SequenceDNN_integrative(dnn, simulation_data):
def interpret_data_with_SequenceDNN(dnn, simulation_data):
# get a positive and a negative example from the simulation data
pos_indx = np.where(simulation_data.y_valid==1)[0][0]
pos_X = simulation_data.X_valid[pos_indx:(pos_indx+1)]
Expand Down Expand Up @@ -277,19 +272,21 @@ def interpret_SequenceDNN_integrative(dnn, simulation_data):
if score_type=='Motif Scores':
scores_to_plot = scores[0, _i, :]
else:
scores_to_plot = scores.squeeze(axis=2)
scores_to_plot = scores[0, 0, 0, :]
if motif_label not in motif_labels_cache:
motif_labels_cache.append(motif_label)
add_legend = True
motif_color = motif_colors[motif_labels_cache.index(motif_label)]
ax.plot(scores_to_plot, label=motif_label, c=motif_color)
if add_legend:
leg = ax.legend(loc=[0,0.85], frameon=False, fontsize=font_size,
ncol=3, handlelength=-0.5)
for legobj in leg.legendHandles:
legobj.set_color('w')
for _i, text in enumerate(leg.get_texts()):
text.set_color(motif_color)
if add_legend:
leg = ax.legend(loc=[0,0.85], frameon=False, fontsize=font_size,
ncol=3, handlelength=-0.5)
for legobj in leg.legendHandles:
legobj.set_color('w')
for _j, text in enumerate(leg.get_texts()):
text_color = motif_colors[
motif_labels_cache.index(motif_label_dict[score_type][_j])]
text.set_color(text_color)
for motif_site in motif_sites[key]:
ax.axvspan(motif_site - highlight_width, motif_site + highlight_width,
color='grey', alpha=0.1)
Loading

0 comments on commit cf033e3

Please sign in to comment.