-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_to_streamlines.py
84 lines (55 loc) · 2.23 KB
/
model_to_streamlines.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import nibabel as nib
from dipy.reconst.dti import TensorModel
from dipy.reconst.gqi import GeneralizedQSamplingModel
from dipy.reconst.dsi import DiffusionSpectrumDeconvModel
from dipy.data import get_sphere
from load_data import (get_train_dsi,
get_train_mask,
get_train_rois)
from dipy.tracking.eudx import EuDX
from dipy.reconst.odf import peaks_from_model
# ---------------------------------------------------------------------------
# Load the training data.
# get_train_dsi(30) returns three values, unpacked here as the diffusion
# data, an affine and a gradient table.  The meaning of the argument 30 is
# defined in load_data and is not visible here -- TODO confirm.
# ---------------------------------------------------------------------------
data, affine, gtab = get_train_dsi(30)
mask, _ = get_train_mask()
rois, _ = get_train_rois()

# Optional crops for quick experiments (currently disabled):
# data = data[25-5:25+5, 25-5:25+5, 25-5:25+5]
# mask = mask[25-5:25+5, 25-5:25+5, 25-5:25+5]
# data = data[25 - 10:25 + 10, 25 - 10:25 + 10, 25]
# data = data[:, :, 25]

# Boolean mask of the voxels whose tensor-model FA exceeds 0.1; it is used
# below to restrict where peaks are extracted.
fa_mask = TensorModel(gtab).fit(data, mask).fa > 0.1

# NOTE: despite its name, gqi_model holds a DSI deconvolution model.
# The GQI alternative is kept here, disabled, for reference:
# gqi_model = GeneralizedQSamplingModel(gtab,
#                                       method='gqi2',
#                                       sampling_length=3,
#                                       normalize_peaks=False)
gqi_model = DiffusionSpectrumDeconvModel(gtab)

sphere = get_sphere('symmetric724')

# The positional 0.35 and 30 are presumably the relative peak threshold and
# the minimum separation angle -- confirm against the installed dipy version.
peaks = peaks_from_model(
    gqi_model, data, sphere, 0.35, 30,
    mask=fa_mask,
    normalize_peaks=True)
from dipy.tracking.metrics import length

# Alternative seed selections (currently disabled):
# seeds = np.vstack(np.where((rois >= 1) & (rois <= 6))).T
# seeds = np.vstack(np.where((rois == 1) | (rois == 2))).T

# One row of array indices per voxel with a positive ROI label, stored as a
# C-contiguous array before being handed to EuDX.
seeds = np.ascontiguousarray(np.argwhere(rois > 0))

# Only index 0 along the last axis of the peak arrays is passed to the
# tracker (presumably the strongest peak per voxel -- confirm).
eu = EuDX(
    peaks.peak_values[..., 0],
    peaks.peak_indices[..., 0],
    seeds=seeds,
    odf_vertices=sphere.vertices,
    a_low=0.2,
    step_sz=0.5)

# Keep only the tracks whose `length` is greater than 10.
streamlines = [track for track in eu if length(track) > 10]
from dipy.viz import fvtk
from dipy.viz.colormap import line_colors
from show_streamlines import show_gt_streamlines
from load_data import get_train_gt_fibers

# Renderer shared by the tracked streamlines and the ground-truth fibers.
r = fvtk.ren()

# QuickBundles centroid view (currently disabled):
# from dipy.segment.quickbundles import QuickBundles
# qb = QuickBundles(streamlines, 10., 18)
# fvtk.add(r, fvtk.line(qb.centroids,
#                       line_colors(qb.centroids)))

# Add the tracked streamlines, coloured by line_colors.
fvtk.add(
    r,
    fvtk.line(streamlines, line_colors(streamlines)))

# Ground-truth fibers: every fiber is shifted by 24.5 along each of the three
# axes before display (presumably to move them into the same coordinate frame
# as the tracked streamlines -- confirm against load_data).
streamlines_gt, radii_gt = get_train_gt_fibers()
streamlines_gt = [
    fiber + np.array([24.5, 24.5, 24.5])
    for fiber in streamlines_gt
]
show_gt_streamlines(streamlines_gt, radii_gt, r=r)

# Open the interactive window with both sets of lines.
fvtk.show(r)