-
Notifications
You must be signed in to change notification settings - Fork 1
/
splice7.0.py
278 lines (245 loc) · 11.4 KB
/
splice7.0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
'''
GA数据的云检测
'''
#%%
import numpy as np
from pyhdf.SD import SD, SDC
import matplotlib.pyplot as plt
import os
import cv2
from libtiff import TIFF
import gdal
from gdalconst import *
import pandas as pd
from sklearn.externals import joblib
from sklearn.mixture import GMM, GaussianMixture
#%%
def detectCloud(b):
    """Return a boolean cloud mask for one reflectance cube.

    Confident-cloud and confident-clear pixels are used as seeds for an
    OpenCV marker-based watershed run on an 8-bit rendering of the first
    three bands. The mask is every pixel that does not end up in the
    clear-seed region (label 2).

    Parameters
    ----------
    b : numpy.ndarray
        Array indexed as ``b[row, col, band]`` with at least 6 bands; bands
        0, 1, 3 and 5 are thresholded and bands 0-2 form the guide image.
        The thresholds and the ``1e-4`` scaling suggest reflectance x 10000
        -- TODO confirm against the input GeoTIFFs.

    Returns
    -------
    numpy.ndarray of bool
        ``b.shape[:2]`` mask, True where the watershed label is not 2
        (the cloud region and the -1 boundary pixels).
    """
    # Seed label 1: bright in band 3 and band 0.
    cloud = (b[:, :, 3] > 2000) & (b[:, :, 0] > 1500)
    # Seed label 2: dark in bands 0 and 1, or dark in band 5; never overlaps a cloud seed.
    noCloud = ~cloud & (((b[:, :, 0] < 400) & (b[:, :, 1] < 800)) | (b[:, :, 5] < 1000))
    # cv2.watershed only accepts a 32-bit signed single-channel marker image.
    # Bool arithmetic yields the platform default integer (int64 on Linux/macOS,
    # and on every platform with NumPy 2), which OpenCV rejects, so cast explicitly.
    markers = cloud.astype(np.int32) + 2 * noCloud.astype(np.int32)
    # Guide image: bands 0-2 scaled to [0, 1], then to 8-bit 3-channel as watershed requires.
    img = np.clip(b[:, :, :3] * 1e-4, 0, 1).astype(np.float32)
    labels = cv2.watershed(np.uint8(img * 255), markers)
    # watershed sets boundary pixels between regions to -1; only the region
    # grown from the clear seeds (label 2) counts as cloud-free.
    return labels != 2
def gauss(x, mu, sigma, A):
    """Evaluate the scaled Gaussian ``A * exp(-(x - mu)**2 / (2 * sigma**2))``.

    Works on scalars and element-wise on NumPy arrays.
    """
    # A bare `exp` is not defined in this module; use NumPy's element-wise exp.
    return A * np.exp(-(x - mu) ** 2 / 2 / sigma ** 2)
def gmmHist(b):
    """Flag cloud by thresholding a brightness image at a histogram-derived level.

    Brightness is the sum of the first five bands plus 300, with negatives
    clipped to 0, normalised by its maximum. A 4-component mixture model is
    fitted to the 256-bin histogram *counts* of that image. The threshold is
    the lower edge of a bin taken from the component that contains the
    brightest bin.

    Parameters
    ----------
    b : numpy.ndarray
        ``b[row, col, band]`` cube with at least five bands.

    Returns
    -------
    numpy.ndarray of bool
        ``b.shape[:2]`` mask, True for invalid pixels (any of bands 0-4
        <= -100) or pixels whose brightness is at or above the threshold.
    """
    flag = np.any(b[:, :, :5] <= -100, axis=-1)
    light = np.sum(b[:, :, :5], -1) + 300
    light = np.where(light < 0, 0, light)
    light = light / np.max(light)
    hist, edges = np.histogram(light, 256, (0, 1))
    # Exclude the darkest bin (includes the pixels clipped to 0 above) from clustering.
    hist[0] = 0
    # NOTE(review): sklearn.mixture.GMM was removed in scikit-learn 0.20;
    # GaussianMixture (also imported by this file) is its replacement.
    clf = GMM(4)
    # fit_predict both fits and labels; a separate fit() beforehand is redundant.
    labels = clf.fit_predict(hist[:, None])
    # NOTE(review): second bin of the component holding the last (brightest)
    # bin; why [1] rather than [0] is not visible here. Raises IndexError if
    # that component contains a single bin -- confirm intent.
    minIdx = np.nonzero(labels == labels[-1])[0][1]
    # Threshold in brightness units: the lower edge of bin minIdx
    # (comparing `light` against hist[minIdx] would compare against a count).
    return flag | (light >= edges[minIdx])
def gmm(b):
    """Cluster pixels with a 4-component mixture model and flag the brightest cluster.

    The model is fitted on bands 0-4 of every pixel. The cluster containing
    the pixel with the largest band-0 value is treated as cloud.

    Parameters
    ----------
    b : numpy.ndarray
        ``b[row, col, band]`` cube with at least five bands.

    Returns
    -------
    numpy.ndarray of bool
        ``b.shape[:2]`` mask, True for invalid pixels (any of bands 0-4
        <= -100) or pixels in the cloud cluster.
    """
    flag = np.any(b[:, :, :5] <= -100, axis=-1)
    h, w = b.shape[:2]
    pixels = b[:, :, :5].reshape((-1, 5))
    # NOTE(review): sklearn.mixture.GMM was removed in scikit-learn 0.20;
    # GaussianMixture (also imported by this file) is its replacement.
    clf = GMM(4)
    clf.fit(pixels)
    labels = clf.predict(pixels).reshape((h, w))
    # First pixel in row-major order where band 0 reaches its maximum.
    row, col = np.unravel_index(np.argmax(b[:, :, 0]), (h, w))
    # predict() needs a 2-D (n_samples, n_features) array, so wrap the single pixel.
    label = clf.predict(b[row, col, :5].reshape(1, -1))[0]
    return flag | (labels == label)
def _readReflectance(path):
    """Read the 7 bands of ``path`` into an int16 cube plus a validity mask.

    Returns
    -------
    data : numpy.ndarray, int16, shape (rows, cols, 7)
        Band values clipped to [0, 10000]; values below -100 become 10000.
    valid : numpy.ndarray of bool, shape (rows, cols)
        True where each of the first four bands lies in [-100, 10000).
    """
    ds = gdal.Open(path)
    data = None
    valid = None
    for j in range(7):
        bi = ds.GetRasterBand(j + 1).ReadAsArray()
        if data is None:
            data = np.zeros(bi.shape + (7,), np.int16)
            valid = np.ones(bi.shape, bool)
        data[:, :, j] = np.where(bi < -100, 10000, np.clip(bi, 0, 10000))
        # Only bands 0-3 take part in the validity test.
        if j < 4:
            valid &= (bi >= -100) & (bi < 10000)
    return data, valid


def _indexImage(data):
    """Return the per-pixel index used to rank observations.

    NDVI (from bands 3 and 2) where it exceeds NDWI and the brightest of
    bands 0-2 is above 50; otherwise NDWI (from max of bands 0-2 and max of
    bands 5-6). A 0/0 ratio yields NaN; NumPy's warnings for it are silenced.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        NDVI = (data[:, :, 3] - data[:, :, 2]) / (data[:, :, 3] + data[:, :, 2])
        maxVis = np.max(data[:, :, :3], axis=-1)
        SWIR = np.max(data[:, :, 5:], axis=-1)
        SWIR = np.where(SWIR >= -100, SWIR, data[:, :, 2])
        NDWI = (maxVis - SWIR) / (np.abs(maxVis) + np.abs(SWIR))
    return np.where((NDVI > NDWI) & (maxVis > 50), NDVI, NDWI)


def findBetter(GA, tifList):
    """Merge two observations of the same tile and day into one index image.

    Parameters
    ----------
    GA : sequence of str
        Two raster paths; band 2 of each is read as a per-pixel "zenith"
        layer used to break ties (presumably view zenith -- confirm).
    tifList : sequence of str
        Two 7-band reflectance rasters, in the same order as ``GA``. The
        script's source encoding assumes index 0 is the MOD file and index 1
        the MYD file.

    Returns
    -------
    blue : numpy.ndarray of float
        ``_indexImage`` value of the chosen observation; -200 where both
        observations are flagged (cloud or invalid).
    landMask : numpy.ndarray of int
        0 where both observations are flagged, 1 where the first was chosen,
        2 where the second was chosen.
    """
    zenithList = []
    blueList = []
    flag = []  # per observation: True = cloud or invalid
    for k in range(2):
        zenithList.append(gdal.Open(GA[k]).GetRasterBand(2).ReadAsArray())
        data, validFlag = _readReflectance(tifList[k])
        blueList.append(_indexImage(data))
        flag.append(detectCloud(data) | ~validFlag)
    # Use the first observation where it is clear and either the second is
    # flagged or the first has the smaller zenith value. The comparison is
    # between the two full zenith layers, pixel by pixel.
    switchFlag = ~flag[0] & (flag[1] | (zenithList[0] < zenithList[1]))
    bothFlagged = flag[0] & flag[1]
    blue = np.where(switchFlag, blueList[0], blueList[1])
    blue = np.where(bothFlagged, -200, blue)
    landMask = np.where(bothFlagged, 0, np.where(switchFlag, 1, 2))
    return blue, landMask
def readImage(img_path):
    """Return the geotransform and projection WKT of a raster.

    Parameters
    ----------
    img_path : str
        Path of the raster to open read-only.

    Returns
    -------
    tuple
        ``(geoTransform, projection)`` on success. ``(None, None)`` when GDAL
        cannot open the file, so ``geotrans, proj = readImage(...)`` still
        unpacks and ``writeImage`` simply skips georeferencing.
    """
    # Open the raster read-only.
    dataset = gdal.Open(img_path, GA_ReadOnly)
    if dataset is None:
        print("Unable to open image file.")
        return None, None
    print("Open image file success.")
    geoTransform = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()  # projection as WKT
    return geoTransform, im_proj
def writeImage(bands, path, geotrans=None, proj=None):
    """Write an array to ``path`` as a GeoTIFF.

    Parameters
    ----------
    bands : numpy.ndarray or None
        ``(rows, cols)`` or ``(rows, cols, nBands)`` array; nothing is
        written when None.
    path : str
        Output file path.
    geotrans : sequence of 6 floats, optional
        GDAL affine geotransform.
    proj : str, optional
        WKT string, or one of the aliases 'WGS84'/'wgs84'/'EPSG:4326'/
        'EPSG-4326'/'4326' or 'EPSG:3857'/'EPSG-3857'/'3857'.

    The GDAL band type follows the array dtype: uint8 -> Byte,
    int8/int16 -> Int16, uint16 -> UInt16, anything else -> Float32.
    """
    if bands is None:
        return
    projection = [
        # WGS84 geographic coordinates (EPSG:4326)
        """GEOGCS["WGS 84", DATUM["WGS_1984", SPHEROID["WGS 84", 6378137, 298.257223563, AUTHORITY["EPSG", "7030"]], AUTHORITY["EPSG", "6326"]], PRIMEM["Greenwich", 0, AUTHORITY["EPSG", "8901"]], UNIT["degree", 0.01745329251994328, AUTHORITY["EPSG", "9122"]], AUTHORITY["EPSG", "4326"]]""",
        # Pseudo-Mercator / spherical Mercator / Web Mercator (EPSG:3857)
        """PROJCS["WGS 84 / Pseudo-Mercator",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Mercator_1SP"],PARAMETER["central_meridian",0],PARAMETER["scale_factor",1],PARAMETER["false_easting",0],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["X",EAST],AXIS["Y",NORTH],EXTENSION["PROJ4","+proj=merc +a=6378137 +b=6378137 +lat_ts=0.0 +lon_0=0.0 +x_0=0.0 +y_0=0 +k=1.0 +units=m +nadgrids=@null +wktext +no_defs"],AUTHORITY["EPSG","3857"]]"""
    ]
    # All bands share one size and dtype, so treat a 2-D array as one band.
    if bands.ndim == 2:
        bands = bands[:, :, None]
    img_height, img_width, num_bands = bands.shape
    # Match dtype names exactly: a substring test sends uint16 to Int16
    # (losing values above 32767) and signed int8 to unsigned Byte (losing negatives).
    datatype = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Int16,
        'int16': gdal.GDT_Int16,
        'uint16': gdal.GDT_UInt16,
    }.get(bands.dtype.name, gdal.GDT_Float32)
    # Create the driver, then the output dataset.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, img_width, img_height, num_bands, datatype)
    if dataset is None:
        print("save image failed.")
        return
    if geotrans is not None:
        dataset.SetGeoTransform(geotrans)  # affine transform
    if proj is not None:
        # Compare by value; `is` on strings tests object identity and is unreliable.
        if proj in ('WGS84', 'wgs84', 'EPSG:4326', 'EPSG-4326', '4326'):
            dataset.SetProjection(projection[0])
        elif proj in ('EPSG:3857', 'EPSG-3857', '3857'):
            dataset.SetProjection(projection[1])
        else:
            dataset.SetProjection(proj)
    for i in range(num_bands):
        dataset.GetRasterBand(i + 1).WriteArray(bands[:, :, i])
    # Flush and release the dataset so GDAL finishes writing the file.
    dataset.FlushCache()
    dataset = None
    print("save image success.")
#%%
GAPath = './MOD09GA_ST/'
regionList = ['.h28v05', '.h27v05', '.h28v06', '.h27v04']
# Inclusive day-of-year ranges per region: "after" counts up from the first
# value to the second, "before" counts down from the first to the second.
tifFileAfterList = [[224, 226], [224, 229], [223, 225], [229, 233]]
tifFileBeforeList = [[217, 213], [212, 207], [219, 207], [218, 207]]
H = W = 2400
tifPath = './MOD09GA/'
GQ = 'D:/Data/MOD09GQ/'
qc = './QC_500m/'
outPath = './output3/'
# GDAL's Create() returns None when the output directory does not exist.
os.makedirs(outPath, exist_ok=True)
# Sorted listings put each date's MOD file before its MYD file. findBetter
# labels its first input 1 and its second 2, and NO below maps those to
# MOD (even codes) and MYD (odd codes), so the order must be deterministic;
# os.listdir order is arbitrary.
gaFiles = sorted(os.listdir(GAPath))
tifFiles = sorted(os.listdir(tifPath))
for regionIdx, region in enumerate(regionList):
    tile = region[1:]
    for name in ('after', 'before'):
        if name == 'before':
            first, last = tifFileBeforeList[regionIdx]
            days = range(first, last - 1, -1)
        else:
            first, last = tifFileAfterList[regionIdx]
            days = range(first, last + 1)
        # Zero-pad to three digits so a day < 100 cannot match inside a
        # longer day number (e.g. '99' inside '199'); days >= 100 are unchanged.
        date = ['%03d%s' % (day, region) for day in days]
        nDates = len(date)

        # Read every date, merging its MOD/MYD pair with findBetter.
        landMask = np.zeros((H, W, nDates), np.int8)
        blue = np.full((H, W, nDates), -200, np.float32)
        flag = np.ones((H, W), bool)  # True where every date is flagged
        for k, word in enumerate(date):
            matched = [f for f in gaFiles if word in f]
            for fileName in matched:
                print(fileName)
            b, tmpLandMask = findBetter([GAPath + f for f in matched],
                                        [tifPath + f for f in matched])
            blue[:, :, k] = b
            landMask[:, :, k] = tmpLandMask
            flag &= (tmpLandMask == 0)
        print(np.sum(flag))

        # Largest index value over the dates. nanmax skips NaN from 0/0 ratios;
        # a sort-then-reverse ranks NaN first, and NaN then matches no date.
        best = np.nanmax(blue, axis=-1)
        # NO encodes the chosen source per pixel: 2*dateIndex for MOD,
        # 2*dateIndex + 1 for MYD; the earliest matching date in `date` wins.
        NO = np.zeros((H, W), np.int8)
        unassigned = np.ones((H, W), bool)
        for k in range(nDates):
            hit = (blue[:, :, k] == best) & (landMask[:, :, k] > 0)
            NO += np.where(hit & unassigned, 2 * k + landMask[:, :, k] - 1, 0)
            unassigned &= ~hit
        print(np.sum(unassigned))

        # Composite the seven bands from each pixel's chosen source.
        bandsGA = np.zeros((H, W, 7), np.int16)
        bandsTmp = np.zeros_like(bandsGA)
        for k, word in enumerate(date):
            for fileName in tifFiles:
                if word not in fileName:
                    continue
                ds = gdal.Open(tifPath + fileName)
                for j in range(7):
                    band = ds.GetRasterBand(j + 1).ReadAsArray()
                    bandsTmp[:, :, j] = np.clip(band, -100, 10000)
                source = 2 * k + (0 if 'MOD' in fileName else 1)
                sameDateFlag = (NO == source)
                bandsGA += np.int16(bandsTmp * sameDateFlag[:, :, None])
        bandsGA = np.where(flag[:, :, None], -200, bandsGA).astype(np.int16)

        # WI mask from the composite (& binds tighter than |, as in the original expression).
        maxVis = np.max(bandsGA[:, :, :3], axis=-1)
        maxSWIR = np.max(bandsGA[:, :, 5:], axis=-1)
        WI = ((maxVis > maxSWIR)
              | ((maxVis < 0) & (maxVis >= -100))
              | ((maxVis > bandsGA[:, :, 3]) & (bandsGA[:, :, 3] >= -100)))

        # Save the composite and WI with the tile's reference georeferencing.
        geotrans, proj = readImage(GQ + tile + '_AllWDays_percent.tiff')
        writeImage(bandsGA, outPath + tile + '_' + name + '.tiff',
                   geotrans=geotrans, proj=proj)
        writeImage(WI, outPath + tile + '_' + name + '_WI.tiff',
                   geotrans=geotrans, proj=proj)