-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path pystretcher.py
executable file
·286 lines (236 loc) · 10.5 KB
/
pystretcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
#!/usr/bin/env python
#Internal imports
from pystretch.core import OptParse, Stats, Timer
from pystretch.core.GdalIO import OpenDataSet, create_output
from pystretch.masks import Segment
import pystretch.core.globalarr as glb
from pystretch.core import maskndv
#Debugging imports
#import profile
#Core imports
import multiprocessing as mp
import ctypes
import sys
import time
import gc
import numpy as np
np.seterr(all='ignore')
#External imports
# External imports, each guarded so that a missing package prints a readable
# message before the ImportError propagates.
try:
    from osgeo import gdal
    from osgeo.gdalconst import *
except ImportError:
    print("GDAL and the GDAL python bindings must be installed.")
    raise
# The version check lives outside the try above so a too-old GDAL is not
# misreported as "not installed". It raises a real exception: a bare `raise`
# with no active exception is itself an error (TypeError/RuntimeError).
if int(gdal.VersionInfo('VERSION_NUM')) < 1800:
    print('ERROR: Python bindings of GDAL version 1.8.0 or later required')
    raise ImportError(
        'Python bindings of GDAL version 1.8.0 or later required')
# NOTE(review): numpy is already imported unconditionally earlier in this
# file, so this guard only matters if that earlier import is removed.
try:
    import numpy as np
except ImportError:
    print("NumPY must be installed.")
    raise
# scipy is optional; only warn when it is missing.
try:
    import scipy
except ImportError:
    print("Some functionality will not work without scipy installed.")
# GDAL raster data type codes (GDT_Byte=1, GDT_UInt16=2, GDT_Int16=3,
# GDT_UInt32=4, GDT_Int32=5, GDT_Float32=6, GDT_Float64=7) mapped to the
# ctypes element type of the shared-memory buffer (_gdal_to_ctypes) and to the
# numpy dtype used to view that buffer (_gdal_to_numpy).
#
# The two tables MUST agree on element size: main() allocates an
# mp.RawArray with the ctypes type and reinterprets it with np.frombuffer
# using the numpy dtype before reshaping. Unsigned inputs are widened to a
# larger signed type so every input value fits and stretch arithmetic can go
# negative without wrapping.
_gdal_to_ctypes = {1: ctypes.c_short, 2: ctypes.c_int32, 3: ctypes.c_int16,
                   4: ctypes.c_int64, 5: ctypes.c_int32, 6: ctypes.c_float,
                   7: ctypes.c_double}
_gdal_to_numpy = {1: np.int16, 2: np.int32, 3: np.int16, 4: np.int64,
                  5: np.int32, 6: np.float32, 7: np.float64}
# NOTE(review): not referenced in this file. Several entries look
# size/sign-mismatched (c_wchar is 4 bytes on Linux; c_uint/c_ulong are
# unsigned; c_long/c_ulong are 8 bytes on 64-bit Linux) and on platforms where
# ctypes aliases c_int to c_long the later key wins. Verify before relying on it.
_ctypes_to_np = {
    ctypes.c_char: np.int8,
    ctypes.c_wchar: np.int16,
    ctypes.c_byte: np.int8,
    ctypes.c_ubyte: np.uint8,
    ctypes.c_short: np.int16,
    ctypes.c_ushort: np.uint16,
    ctypes.c_int: np.int32,
    ctypes.c_uint: np.int32,
    ctypes.c_long: np.int32,
    ctypes.c_ulong: np.int32,
    ctypes.c_float: np.float32,
    ctypes.c_double: np.float64
}
def segment_image(xsize, ysize, xsegment, ysegment):
    """Split an ``xsize`` x ``ysize`` image into a grid of rectangular windows.

    ``xsegment`` / ``ysegment`` are the requested number of pieces along the
    x (column) and y (row) axes; ``None`` means 1. The window step on each
    axis is ``size // segments`` (at least 1 pixel). When the size does not
    divide evenly, a narrower remainder window is appended on that axis, so
    the actual count can exceed the request by one.

    We assume every band has the same dimensions and pixel size; this may not
    hold true for formats like JP2k.

    Returns a list of ``(xstart, ystart, ncolumns, nrows)`` tuples in
    row-major order.
    """
    if xsegment is None:
        xsegment = 1
    if ysegment is None:
        ysegment = 1
    # Floor division keeps the steps integral on Python 2 and 3; clamping to 1
    # avoids a zero range step when more segments than pixels are requested.
    intervalx = max(1, xsize // xsegment)
    intervaly = max(1, ysize // ysegment)
    output = []
    for y in range(0, ysize, intervaly):
        numberofrows = min(intervaly, ysize - y)
        for x in range(0, xsize, intervalx):
            numberofcolumns = min(intervalx, xsize - x)
            output.append((x, y, numberofcolumns, numberofrows))
    return output
def scale(a, b, bandmin, bandmax):
    """Linearly map ``glb.sharedarray`` from ``[bandmin, bandmax]`` onto ``[a, b]``.

    The result is a newly computed array that REPLACES ``glb.sharedarray``:
    the module global is rebound rather than written in place. Afterwards it
    is no longer the shared-memory buffer, so callers that keep using that
    buffer must hold their own reference to it.

    When ``bandmax == bandmin`` the input range is degenerate. Every value is
    mapped to ``a`` instead of dividing by zero, which (with numpy errors
    ignored by this module's ``np.seterr(all='ignore')``) would silently yield
    inf/NaN.
    """
    a = float(a)
    b = float(b)
    span = bandmax - bandmin
    if span == 0:
        # Keep shape (and any mask) of the input, but collapse to `a`.
        glb.sharedarray = (glb.sharedarray - bandmin) * 0.0 + a
    else:
        glb.sharedarray = (((b - a) * (glb.sharedarray - bandmin)) / span) + a
def _row_slices(nrows, cores):
    """Split rows ``[0, nrows)`` into at most ``cores`` contiguous slices.

    Every row is covered exactly once and no slice extends past ``nrows``.
    An empty list is returned when ``nrows`` is 0. A step of
    ``nrows // cores`` would be 0, and crash ``range``, whenever a chunk has
    fewer rows than there are cores; the ceiling step here cannot be 0.
    """
    cores = max(1, cores)
    step = max(1, -(-nrows // cores))  # ceiling division
    return [slice(start, min(start + step, nrows))
            for start in range(0, nrows, step)]


def _process_chunk(band, outband, window, slices, shared, stretch, pool,
                   args, stats, user_ndv, chunk_stats):
    """Read one window of ``band``, stretch it in the pool, and write it out.

    ``window`` is ``(xstart, ystart, ncols, nrows)`` in pixel coordinates.
    ``slices`` are first-axis slices of the shared buffer, each passed to one
    ``stretch`` call in a pool worker. ``shared`` is the plain numpy view of
    the shared-memory buffer. ``stats`` is the band's statistics dict (keys
    ``ndv``, ``minimum``, ``maximum`` are read). ``user_ndv`` is the NDV the
    user requested for the output. ``chunk_stats`` enables recomputing
    statistics for this window before stretching.
    """
    xstart, ystart, ncols, nrows = window
    # Masking and scale() rebind glb.sharedarray to new objects, and a scaled
    # array no longer aliases the shared memory the workers were initialised
    # with. Re-point at the shared view so every window is read into memory
    # the workers can see.
    glb.sharedarray = shared
    glb.sharedarray[:nrows, :ncols] = band.ReadAsArray(xstart, ystart,
                                                       ncols, nrows)
    band_ndv = stats['ndv']
    # If the input has an NDV - mask it. copy=False keeps the masked array's
    # data aliased to the shared buffer.
    if band_ndv is not None:
        glb.sharedarray = np.ma.masked_equal(glb.sharedarray, band_ndv,
                                             copy=False)
    if chunk_stats:
        args.update(Stats.get_array_stats(glb.sharedarray, stretch))
    # pool.apply blocks until each call returns, so slices run one at a time.
    for rows in slices:
        pool.apply(stretch, args=(rows, args))
    if band_ndv is not None and user_ndv is not None:
        # Compare on the plain view: on a masked array the NDV pixels are
        # masked and may not be selected by `==`.
        window_view = shared[:nrows, :ncols]
        window_view[window_view == band_ndv] = user_ndv
        outband.SetNoDataValue(float(user_ndv))
    if args['scale'] is not None:
        # Scale the data before writing to disk.
        scale(args['scale'][0], args['scale'][1],
              stats['minimum'], stats['maximum'])
    outband.WriteArray(glb.sharedarray[:nrows, :ncols], xstart, ystart)
    glb.sharedarray = shared


def main(args):
    """Stretch every band of ``args['input']`` and write it to ``args['output']``.

    ``args`` is the option dict produced by ``OptParse``. The keys read here
    are ``ncores``, ``input``, ``output``, ``outputformat``, ``dtype``,
    ``byline``, ``bycolumn``, ``horizontal_segments``, ``vertical_segments``,
    ``statsper``, ``ndv``, ``scale`` and ``quiet``. ``args`` is mutated:
    ``statsper`` is forced on in line/column mode, and each band's statistics
    are merged into it before that band is stretched.

    Bands are processed one after another. Within a band, windows (lines,
    columns, or segments) are read into a shared-memory buffer and the
    stretch function is applied to row slices of it in a multiprocessing pool.
    """
    starttime = Timer.starttimer()
    # Cache thrashing is common when working with large files; a larger than
    # normal GDAL block cache (1 GB) helps alleviate misses.
    gdal.SetCacheMax(1073741824)
    stretch = OptParse.argget_stretch(args)
    cores = args['ncores']
    if cores is None:
        cores = mp.cpu_count()

    # Load the input dataset and create the output with the requested dtype.
    dataset = OpenDataSet(args['input'])
    raster = dataset.load()
    xsize, ysize, nbands, projection, geotransform = dataset.info(raster)
    bands = [raster.GetRasterBand(i) for i in range(1, nbands + 1)]
    bandstats = [Stats.get_band_stats(band) for band in bands]
    banddtype = bands[0].DataType
    output = create_output(args['outputformat'], args['output'],
                           xsize, ysize, len(bands), projection,
                           geotransform, gdal.GetDataTypeByName(args['dtype']))

    # Choose the processing mode and the shared buffer shape (rows, cols).
    byline = args['byline'] is True
    bycolumn = not byline and args['bycolumn'] is True
    if byline:
        args['statsper'] = True
        bufrows, bufcols = cores, xsize
    elif bycolumn:
        args['statsper'] = True
        bufrows, bufcols = ysize, cores
    else:
        if (args['horizontal_segments'] is not None
                or args['vertical_segments'] is not None):
            # The user is defining the segmentation.
            segments = segment_image(xsize, ysize, args['vertical_segments'],
                                     args['horizontal_segments'])
        else:
            segments = [(0, 0, xsize, ysize)]
        # segment_image puts a full-size window first; later ones may be
        # narrower remainders, so the first one sizes the buffer.
        bufcols, bufrows = segments[0][2:]

    # Preallocate a shared-memory buffer; glb.init hands it to each worker.
    carray = mp.RawArray(_gdal_to_ctypes[banddtype], bufrows * bufcols)
    shared = np.frombuffer(
        carray, dtype=_gdal_to_numpy[banddtype]).reshape(bufrows, bufcols)
    glb.sharedarray = shared
    pool = mp.Pool(processes=cores, initializer=glb.init, initargs=(shared,))

    # args.update(stats) below overwrites args['ndv'] with each band's own
    # NDV, so capture the user's value once, up front.
    user_ndv = args['ndv']
    # One single-row slice per core, used in line and column modes.
    unit_slices = [slice(i, i + 1) for i in range(cores)]
    # NOTE(review): progress is printed when args['quiet'] is truthy; confirm
    # the option's sense in OptParse.
    try:
        # A conscious decision to iterate over the bands in serial - an IO
        # bottleneck anyway.
        for j, band in enumerate(bands):
            stats = bandstats[j]
            args.update(stats)
            if byline:
                for y in range(0, ysize, cores):
                    window = (0, y, xsize, min(cores, ysize - y))
                    # Per-line statistics are not recomputed in line mode.
                    _process_chunk(band, output.GetRasterBand(j + 1), window,
                                   unit_slices, shared, stretch, pool, args,
                                   stats, user_ndv, False)
                    if args['quiet']:
                        sys.stdout.write(
                            "Processed {} of {} lines \r".format(y, ysize))
                        sys.stdout.flush()
            elif bycolumn:
                for x in range(0, xsize, cores):
                    window = (x, 0, min(cores, xsize - x), ysize)
                    # NOTE(review): unit_slices select buffer ROWS 0..cores-1,
                    # but in column mode the buffer is (ysize, cores); confirm
                    # the stretch function interprets the slice per column.
                    _process_chunk(band, output.GetRasterBand(j + 1), window,
                                   unit_slices, shared, stretch, pool, args,
                                   stats, user_ndv, args['statsper'] is True)
                    if args['quiet']:
                        sys.stdout.write(
                            "Processed {} of {} columns \r".format(x, xsize))
                        sys.stdout.flush()
            else:
                # Not processing line by line: distribute each segment's rows
                # over the available cores.
                for window in segments:
                    _process_chunk(band, output.GetRasterBand(j + 1), window,
                                   _row_slices(window[3], cores), shared,
                                   stretch, pool, args, stats, user_ndv,
                                   args['statsper'] is True)
        Timer.totaltime(starttime)
    finally:
        pool.close()
        pool.join()
    # Drop the GDAL dataset references so they are closed and flushed.
    dataset = None
    output = None
def init(shared_arr_):
    """Pool initializer: publish ``shared_arr_`` as this module's global ``sharedarray``.

    The array must reach worker processes by inheritance through the
    initializer, not be passed as a per-task argument.
    NOTE(review): main() initialises its pool with ``glb.init``, not this
    function, so it appears unused in this file.
    """
    globals()['sharedarray'] = shared_arr_
if __name__ == '__main__':
    # Command-line entry point. freeze_support() lets multiprocessing work in
    # frozen Windows executables and has no effect elsewhere.
    mp.freeze_support()
    main(OptParse.argparse_arguments())