Commit a5c2426e authored by Shanyu Zhou's avatar Shanyu Zhou
Browse files

Upload New File

parent 0f5ad242
# -*- coding: utf-8 -*-
"""
Created on 2020-Oct-23 2:39 PM
@Author : Shanyu Zhou
"""
import os
import numpy as np
from osgeo import gdal
from scipy import interpolate
import time
import pandas as pd
# worldviewvl = np.array([1210, 1570, 1660, 1730, 2165, 2205, 2260, 2330])
# ********************** Image handling start **********************
class imagecube(object):
def __init__(self):
self.ncol = None
self.nrow = None
self.nbands = None
# self.byteorder = None
self.wvl = np.array([])
self.unit = None
self.fwhm = np.array([])
self.datacube = np.array([])
self.mapinfo = None
self.projinfo = None
self.source = ''
def readimage(filen):
filename = []
path = os.path.split(filen)[0]
filenn = os.path.split(filen)[1]
if filenn.find('.') == -1:
for i in os.listdir(path):
if i.find(filenn) != -1:
if i.find('hdr') == -1 and i.find('.enp') == -1:
filename = path +'/' + i
elif filenn.find('hdr') != -1:
for i in os.listdir(path):
if i.split('.')[0]==filenn.split('.')[0]:
if i.find('hdr') == -1 and i.find('.enp') == -1:
filename = path +'/' + i
else:
filename = filen
dataset = gdal.Open(filename)
HSI_data = dataset.ReadAsArray()
return HSI_data
def readhdr(filen):
if filen.find('.') == -1:
hdrname = filen +'.hdr'
elif filen.find('hdr') == -1:
hdrname = filen.split('.')[0]+'.hdr'
else:
hdrname = filen
HSI_hdr = imagecube()
HSI_hdr.source = filen
HSI_hdr.datacube = readimage(filen)
f = open(hdrname,'r')
fhdr = f.readlines()
f.close()
n = 0
st,ed = [],[]
for line in fhdr:
if line.find('{')>=0:
st.append(n)
if line.find('}')>=0:
ed.append(n)
n=n+1
st,ed = np.array(st),np.array(ed)
for i in range(0,len(fhdr)):
line = fhdr[i]
if line.find('=')>=0:
hsitem = line.split('=')[0]
hsdata = line.split('=')[1]
if hsitem.find('samp')>=0:
HSI_hdr.ncol = int(hsdata)
elif hsitem.find('lines')>=0:
HSI_hdr.nrow = int(hsdata)
elif hsitem.strip() == 'bands':
HSI_hdr.nbands = int(hsdata)
elif hsitem.strip() == 'wavelength':
if hsdata.find('{')>=0 and hsdata.find('}')>=0:
HSI_hdr.wvl = np.array(list(map(float,hsdata.strip()[1:-1].split(','))))
else:
toline = ''
for item in fhdr[i:ed[np.where(st == i)][0]+1]:
toline = toline+item.strip('\n')
hsdata = toline.split('=')[1]
HSI_hdr.wvl = np.array(list(map(float, hsdata.strip()[1:-1].split(','))))
elif hsitem.find('unit')>=0:
HSI_hdr.unit = hsdata
elif hsitem.find('fwhm') >= 0:
toline = ''
for item in fhdr[i:ed[np.where(st == i)][0] + 1]:
toline = toline + item.strip('\n')
hsdata = toline.split('=')[1]
HSI_hdr.fwhm = np.array(list(map(float, hsdata.strip()[1:-1].split(','))))
elif hsitem.find('map info') >=0:
HSI_hdr.mapinfo = hsdata
elif hsitem.find('projection info') >= 0:
HSI_hdr.projinfo = hsdata
return HSI_hdr
def squez(dataset):
if len(dataset.shape)<3:
return np.reshape(dataset, [dataset.shape[0] * dataset.shape[1]])
elif len(dataset.shape) ==3:
return np.reshape(dataset, [dataset.shape[0], dataset.shape[1] * dataset.shape[2]])
def recover(dataset,refer):
return np.reshape(dataset, refer.shape)
def buildimagecube(dc,wvl):
imcb = imagecube()
imcb.datacube = dc
imcb.wvl = wvl
if len(dc.shape) ==3:
imcb.nbands = dc.shape[0]
imcb.ncol = dc.shape[2]
imcb.nrow = dc.shape[1]
elif len(dc.shape) ==2:
imcb.nbands = 1
imcb.ncol = dc.shape[1]
imcb.nrow = dc.shape[0]
return imcb
def writebsq(hhdr,savename,savepath ='./',classnames=[]):
if not os.path.exists(savepath):
os.makedirs(savepath)
if savepath[-1]!='/':
savepath = savepath+'/'
classflag = 0
if len(classnames)>0:
classflag =1
classes = len(set(squez(hhdr.datacube)))
clu = classcolorlib(classes)
hhdr.datacube = np.float64(hhdr.datacube)
savebsq = savepath+savename
savehdr = savebsq.replace('.bsq','')+'.hdr'
empbsq = hhdr.datacube
if not empbsq.flags['C_CONTIGUOUS']:
empbsq = np.ascontiguousarray(empbsq)
with open(savebsq,'wb') as bsqf:
if hhdr.nbands !=1:
for i in range(0,hhdr.nbands):
bsqf.write(empbsq[i,:,:])
else:
bsqf.write(empbsq[ :, :])
empbsq.astype('float64')
# dt = empbsq.dtype
# if (dt == "uint8"):
# daty = 1
# elif (dt == "int16"):
# daty = 2
# elif (dt == "int32"):
# daty = 3
# elif (dt == "float32"):
# daty = 4
# elif (dt == "float64"):
daty = 5
# elif (dt == "complex64"):
# daty = 6
# elif (dt== "complex128"):
# daty = 9
# elif (dt== "uint16"):
# daty = 12
# elif (dt== "uint32"):
# daty = 13
# elif (dt== "int64"):
# daty = 14
# elif (dt== "uint64"):
# daty = 15
description = 'classification'
with open(savehdr,'w') as hdrf:
hdrf.writelines('ENVI\n')
hdrf.writelines("description = {\n")
hdrf.writelines(" " + description + " [%s]}\n" % (time.ctime()))
hdrf.writelines("samples = " + str(hhdr.ncol) + "\n")
hdrf.writelines("lines = " + str(hhdr.nrow) + "\n")
hdrf.writelines("bands = " + str(hhdr.nbands) + "\n")
hdrf.writelines("header offset = 0\n")
if classflag == 1:
hdrf.writelines("file type = ENVI Classification\n")
hdrf.writelines("data type = "+ str(daty) +"\n")
hdrf.writelines("interleave = bsq\n")
hdrf.writelines("byte order = 0\n")
if classflag == 1:
hdrf.writelines("classes = "+str(classes)+"\n")
hdrf.writelines("class names = {\n"+", ".join(classnames)+"}\n")
hdrf.writelines("class lookup = {\n"+str(clu).strip('[').strip(']')+"}\n")
if len(hhdr.wvl)!=0:
hdrf.writelines("wavelength units = Nanometers\n")
hdrf.writelines("wavelength = " + str(list(hhdr.wvl)).replace('[','{').replace(']','}') + "\n")
if not hhdr.mapinfo is None:
hdrf.writelines("map info = "+hhdr.mapinfo+"\n")
if not hhdr.projinfo is None:
hdrf.writelines("projection info = "+hhdr.projinfo+"\n")
def classcolorlib(max):
clu = [ 0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 0, 0, 255, 255, 255, 0, 255, 176, 48, 96, 46, 139, 87, 160, 32, 240, 255, 127, 80, 127, 255, 212, 218, 112, 214, 160, 82, 45, 127, 255, 0, 216, 191, 216, 238, 0, 0, 205, 0, 0, 139, 0, 0, 0, 238, 0, 0, 205, 0, 0, 139, 0, 0, 0, 238, 0, 0, 205, 0, 0, 139, 238, 238, 0, 205, 205, 0, 139, 139, 0, 0, 238, 238, 0, 205, 205, 0, 139, 139, 238, 0, 238, 205, 0, 205, 139, 0, 139, 238, 48, 167, 205, 41, 144, 139, 28, 98, 145, 44, 238, 125, 38, 205, 85, 26, 139, 255, 165, 0, 238, 154, 0, 205, 133, 0, 139, 90, 0, 238, 121, 66, 205, 104, 57, 139, 71, 38, 238, 210, 238, 205, 181, 205, 255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 0]
return clu[0:3*max]
# ********************** Image handling end **********************
# ********************** spectral simulation start **********************
def GenerateSpectralFilterSRF(wvl_in,SRFarray,SRFwvl):
# wvl_in = input wavelength band centres [ndarray 1D]
# wvl_out = output / filtered wavelength band centres [ndarray 1D]
# SRFarray = SRFarray sum to one [float/int]
# return [ndarray 2D] shape = (len(wvl_out),len(wvl_in))
Weights = np.array([interpolate.interp1d(SRFwvl,y)(wvl_in) for y in SRFarray])
# Weights = np.exp(-2*(np.power(Win-Wout,2.0)*np.log(4.0)/(np.power(FWHM,2.0))))
for i in range(Weights.shape[0]):
Weights[i] = Weights[i]/np.sum(Weights[i])
return Weights
def specsim(inputf,referf='',saven=''):
inputhdr = readhdr(inputf)
if np.max(inputhdr.wvl)<1000:
inputhdr.wvl = inputhdr.wvl*1000
if len(saven) == 0:
if len(referf)==0:
saven = os.path.split(inputf)[1].split('.')[0]+'_wv3'
else:
saven = os.path.split(inputf)[1].split('.')[0] + '_simulation'
savepath = os.path.split(inputf)[0] + '/simulated_wv3/'
if not os.path.exists(savepath):
os.makedirs(savepath)
if len(referf)!=0:
referhdr = readhdr(referf)
refweight = GenerateSpectralFilterGaussian(inputhdr.wvl, referhdr.wvl, referhdr.fwhm)
refer = np.zeros([referhdr.nbands, inputhdr.datacube.shape[1], inputhdr.datacube.shape[2]])
outputdata = recover(np.dot(refweight,squez(inputhdr.datacube)),refer)
outputhdr = imagecube()
outputhdr.datacube,outputhdr.wvl,outputhdr.nbands,outputhdr.ncol,outputhdr.nrow = outputdata,referhdr.wvl,referhdr.nbands,inputhdr.ncol,inputhdr.nrow
writebsq(outputhdr,saven,savepath)
elif len(referf)==0:
basepath = os.getcwd()+'/wv3cfg/'
wv3srf = basepath + 'WV03_SRF.xlsx'
worldviewvl = np.array([1210, 1570, 1660, 1730, 2165, 2205, 2260, 2330])
# scope = locals()
wv3 = pd.read_excel(wv3srf)
SRF = np.array([eval('wv3.SWIR' + str(i) + '.to_numpy()/np.sum(wv3.SWIR' + str(i) + '.to_numpy())',{'np':np},{"wv3":wv3}) for i in range(1, 9)])
SRFWVL = wv3.Wavelength.to_numpy()
wvweight = GenerateSpectralFilterSRF(inputhdr.wvl, SRF, SRFWVL)
refer = np.zeros([len(worldviewvl), inputhdr.datacube.shape[1], inputhdr.datacube.shape[2]])
outputdata = recover(np.dot(wvweight, squez(inputhdr.datacube)), refer)
inputhdr.datacube = outputdata
inputhdr.wvl = np.array(worldviewvl)
inputhdr.nbands = len(worldviewvl)
inputhdr.source = savepath+saven
writebsq(inputhdr, saven, savepath)
return inputhdr, savepath
def multispecim(path):
filelist = [i for i in os.listdir(path) if i.find('hdr') > 0]
for i in filelist:
print('Converting ' + i + ' to Worldview-3')
_,wvpath = specsim(path + i)
print('Converted data saved in '+ wvpath)
return wvpath
# ********************** spectral simulation end **********************
# ********************** classify tree start **********************
def getMin(x, y):
minXs = []
minYs = []
for index in range(1, x.size - 1):
if (y[index]-y[index-1]) <=0 and (y[index + 1] - y[index])>0:
minXs.append(x[index])
minYs.append(y[index])
return minXs, minYs
def getMax(x, y):
maxXs = []
maxYs = []
maxXs.append(x[0])
maxYs.append(y[0])
for index in range(1, x.size - 1):
if y[index] - y[index - 1] > 0 and y[index + 1] - y[index] <= 0:
maxXs.append(x[index])
maxYs.append(y[index])
maxXs.append(x[x.size - 1])
maxYs.append(y[x.size - 1])
return maxXs, maxYs
def findloc(wvlen,hdrwvl):
wvloc = np.where(abs(hdrwvl-wvlen) == min(abs(hdrwvl-wvlen)))
return wvloc[0][0]
def nhi(r1,r2,r3,wvl,dc,isvflag = 0):
if len(dc.shape) ==3:
d1 = dc[findloc(r1,wvl), :, :]
d2 = dc[findloc(r2,wvl), :, :]
d3 = dc[findloc(r3,wvl), :, :]
hii = 1 - d2 / ((r2 - r1) * (d3 - d1) / (r3 - r1) + d1)
elif len(dc.shape) == 1:
d1 = dc[findloc(r1, wvl)]
d2 = dc[findloc(r2, wvl)]
d3 = dc[findloc(r3, wvl)]
if isvflag == 0:
hii = 1 - d2 / ((r2 - r1) * (d3 - d1) / (r3 - r1) + d1)
else:
addn = np.max(dc) - np.min(dc)
d1, d2, d3 = addn - d1, addn - d2, addn - d3
hii = 1 - d2 / ((r2 - r1) * (d3 - d1) / (r3 - r1) + d1)
return hii
def getclassname(dc):
allcn = ['Background','methyl','C2','C2PS','C3','mixing methyl']
locs, iclass = zip(*[[i,allcn[int(i)]] for i in set(squez(dc))])
if len(iclass)<len(allcn):
for count, value in enumerate(locs):
if count != value:
# print(count,value)
dc[np.where(dc==value)]=count
return dc,iclass
def localtime():
prefixx = time.ctime()[4:19].replace(':', '_').replace(' ', '_')
return prefixx
def slope(x1, y1, x2, y2):
m = (y2-y1)/(x2-x1)
return m
def mkmask(hdr,thre=7000):
sumt = np.sum(hdr.datacube, axis=0)
sumt[np.where(sumt <= thre)] = 0
sumt[np.where(sumt > thre)] = 1
return sumt
def wvplcl(hdr, th1=0.1, th3=0.05, th4=0.02, mask=None):
if mask is None:
mask = []
dc = hdr.datacube
x = hdr.wvl
opt = np.zeros_like(dc[0,:,:])
if len(mask)<=0:
mask = np.zeros_like(opt)+1
print('----------start classification-----------')
xloc,yloc = (mask>0).nonzero()
start = time.time()
for pix, piy in zip(xloc,yloc):
# print(pix,piy)
y = dc[:,pix,piy]
minx,miny = getMin(x,y)
maxx,maxy = getMax(x,y)
k0 = slope(x[0],y[0],x[1],y[1])
k1 = slope(x[3],y[3],x[4],y[4])
# k3 = slope(x[0], y[0], x[1], y[1])
k4 = slope(x[-4], y[-4], x[-1], y[-1])
if y[0]!=np.min(y):
if y[findloc(1570, x)]>y[findloc(1730, x)]:
if 1730 in minx and not (2165 in minx or 2205 in minx) and (nhi(1570,1660,1730,x,y)<0 or (nhi(1570,1660,1730,x,y) >=0 and nhi(2165, 2205, 2260,x,y)>=0)):
if (2165 in maxx or 2205 in maxx):
if (y[0] < y[1] or (y[0] > y[1] and (y[0] - y[1]) / (np.max(y) - np.min(y)) <=0.5)):
if y[-4]>y[-1] and y[3]>y[-1]and y[0]>y[-1]:
if y[3]<=y[0]:
opt[pix,piy] = 1
elif (y[3]-y[0])/(np.max(y)-np.min(y))<0.05:
opt[pix, piy] = 1
elif nhi(1210, 1570, 1660, x, y) < 0 and nhi(1730, 2165, 2205, x, y)>0 and (nhi(1570, 1660, 1730, x, y) >= 0 or (nhi(1570, 1660, 1730, x, y) < 0 and not (2260 in minx) and(y[2]-y[3])< (y[3]-y[4]))):
if 2165 in minx or 1660 in minx or 1730 in minx:
if 2205 in maxx:
if 1660 in minx:
opt[pix, piy] = 2 #
elif nhi(1730, 2165, 2205, x, y)>=th4 and y[3]<y[0]:
opt[pix, piy] = 4
if 2260 in maxx:
if nhi(1570, 1660, 1730, x, y)>=th3 and k0/k1<2 and y[0]>=y[2]:
opt[pix, piy] = 2
elif nhi(1570, 1730, 2165, x, y) > th1 and k1>k4 and not (2165 in minx or 2205 in minx):
if not (2165 in maxx) and not (2205 in maxx) and y[-4] > y[-1] and y[3] > y[-1] and y[0] > y[-1]:
if (nhi(1570,1660,1730,x,y)<0 and nhi(2165, 2205, 2260,x,y)<=0) or (nhi(1570,1660,1730,x,y)>=0 and nhi(2165, 2205, 2260,x,y)>=0):
if nhi(1210, 1570, 1660, x, y) < 0 :
if y[3] < y[0]:
opt[pix, piy] = 5
elif (y[3] - y[0]) / (np.max(y) - np.min(y)) < 0.05:
opt[pix, piy] = 5
elif np.abs(y[0] - y[2]) < np.abs(y[-1] - y[-4]):
if y[3] < y[0]:
opt[pix, piy] = 5
elif (y[3] - y[0]) / (np.max(y) - np.min(y)) < 0.05:
opt[pix, piy] = 5
end = time.time()
dur = end-start
if dur < 60:
print('processing time: %.2f sec' % dur)
else:
print('processing time: %.2f min' % (dur/60))
print('----------finish classification-----------')
return opt
# ********************** classify tree end **********************
# ********************** main WV3 CLASSIFIER start **********************
def singleclassification(classmap,accepted=True):
if accepted:
classmap[np.where(classmap == 5)] = 1
newcm, clasn = getclassname(classmap)
return newcm,clasn
def multiclassification(fpath, isconvolve=False, th1=0.1, th3=0.05, th4=0.02, mask=None, maskvalue=0, prefix=None, accepted=True):
if prefix is None:
prefix = []
if mask is None:
mask = []
if isconvolve:
pass
# path =
else:
path = fpath
ilist = [i for i in os.listdir(path) if i.find('hdr') > 0]
if path[-1] == '/':
resultpath = path + 'resultmap/'
else:
resultpath = path + '/resultmap/'
if len(prefix) == 0:
prefix = localtime()
for i in ilist:
print('start classifying ' + path + i)
hdr = readhdr(path + i)
if hdr.wvl[0] < 1000:
hdr.wvl = np.array(worldviewvl)
if len(mask) == 0:
mask = mkmask(hdr, maskvalue)
classmap = wvplcl(hdr, th1, th3, th4, mask)
if accepted:
classmap[np.where(classmap == 5)] = 1
mask = []
newcm, clasn = getclassname(classmap)
writebsq(buildimagecube(newcm, []),
i.replace('.hdr', '') + '_classresult_threshold_' + str("%.2f" % th4).replace('.', '_'),
resultpath + prefix + '/', clasn)
# print('finish classifying '+ path+i)
print('end processing')
# ********************** main WV3 CLASSIFIER end **********************
if __name__ == '__main__':
bpath = os.getcwd()
# Path containing your dataset (can be hyperspectral data, example recorded by PRISMA)
# Must add '/' in the end of the string
path =bpath+'/example/'
# This line is for spectral resampling to WV3 configuration
tpath = multispecim(path)
# This line is for the classication using default setting in the paper
# Please change the setting in the function on demands
multiclassification(tpath)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment