Commit a3f4cc3e authored by Hoek, Steven's avatar Hoek, Steven
Browse files

These classes are now equipped to also handle multiband rasters

parent bad28b82
...@@ -15,6 +15,7 @@ class InMemoryRaster(Raster, GridEnvelope2D): ...@@ -15,6 +15,7 @@ class InMemoryRaster(Raster, GridEnvelope2D):
data = None data = None
__datatype = None __datatype = None
__open = False __open = False
__nbands = 1 # default
dataformat = 'f' dataformat = 'f'
def __init__(self, filepath, data=None, *datatype): def __init__(self, filepath, data=None, *datatype):
...@@ -32,28 +33,46 @@ class InMemoryRaster(Raster, GridEnvelope2D): ...@@ -32,28 +33,46 @@ class InMemoryRaster(Raster, GridEnvelope2D):
self.__datatype = const.FLOAT; self.__datatype = const.FLOAT;
self.dataformat = 'f' self.dataformat = 'f'
def open(self, mode, ncols=1, nrows=1, xll=0.0, yll=0.0, cellsize=1.0, nodatavalue=-9999.0): def open(self, mode, ncols=1, nrows=1, nbands=1, xll=0.0, yll=0.0, cellsize=1.0, nodatavalue=-9999.0):
super(InMemoryRaster, self).open(mode); super(InMemoryRaster, self).open(mode);
self.__open = True self.__open = True
if self.__datatype == const.INTEGER: if self.__datatype == const.INTEGER:
dtype = np.int dtype = np.int
else: else:
dtype = np.float dtype = np.float
self.__nbands = nbands
if mode[0] == 'w': if mode[0] == 'w':
# Writing mode
if self.data is None: if self.data is None:
self.data = np.zeros((nrows*ncols), dtype=dtype) self.data = np.zeros((nbands*nrows*ncols), dtype=dtype)
else: else:
# Establish height and width
if nbands == 1:
height = self.data.shape[0]
width = self.data.shape[1]
else:
height = self.data.shape[1]
width = self.data.shape[2]
# Now check that the data already has dimensions as indicated
errmsg = "Shape of input data does not match given dimensions!" errmsg = "Shape of input data does not match given dimensions!"
if len(self.data) != nrows: if height != nrows:
raise Exception(errmsg)
if (height > 0) and (width != ncols):
raise Exception(errmsg) raise Exception(errmsg)
else:
if (len(self.data) > 0) and (len(self.data[0]) != ncols):
raise Exception(errmsg)
else: else:
# Reading mode
if self.data == None: raise Exception("Memory was not initialised!") if self.data == None: raise Exception("Memory was not initialised!")
self.data = np.array(self.data, dtype=dtype) self.data = np.array(self.data, dtype=dtype)
# Make sure that the data have the right shape
self.data = self.data.flatten() self.data = self.data.flatten()
self.data.shape = (nrows, ncols) if nbands == 1:
self.data.shape = (nrows, ncols)
else:
self.data.shape = (nbands, nrows, ncols)
# We need to also initialise the grid envelope
self.xll = xll self.xll = xll
self.yll = yll self.yll = yll
self.dx = cellsize self.dx = cellsize
...@@ -68,16 +87,31 @@ class InMemoryRaster(Raster, GridEnvelope2D): ...@@ -68,16 +87,31 @@ class InMemoryRaster(Raster, GridEnvelope2D):
self.currow += 1; self.currow += 1;
if (self.currow > self.nrows): raise StopIteration; if (self.currow > self.nrows): raise StopIteration;
if parseLine: if parseLine:
return self.data[self.currow - 1, :] if self.__nbands == 1:
return self.data[self.currow - 1, :]
else:
return self.data[:, self.currow - 1, :]
else: else:
return None return None
def writenext(self, sequence_with_data): def writenext(self, sequence_with_data):
# Initialise
self.currow += 1
# Check a few things
if not self.__open: raise Exception("Not yet fully initialised!") if not self.__open: raise Exception("Not yet fully initialised!")
self.currow += 1; if (self.currow > self.nrows): raise StopIteration
if (self.currow > self.nrows): raise StopIteration; if self.__nbands == 1:
if len(sequence_with_data) != self.ncols: width = len(sequence_with_data)
else:
width = sequence_with_data.shape[1]
if width != self.ncols:
raise Exception("Attempt to assign line of wrong length!") raise Exception("Attempt to assign line of wrong length!")
self.data[self.currow - 1, :] = sequence_with_data
# Assign the input line to the internal memory structure
if self.__nbands == 1:
self.data[self.currow - 1, :] = sequence_with_data
else:
self.data[:, self.currow - 1, :] = sequence_with_data
return True return True
\ No newline at end of file
# Copyright (c) 2004-2021 WUR, Wageningen # Copyright (c) 2004-2021 WUR, Wageningen
import os.path import os.path
import sys
sys.path.append("../lmgeo")
from lmgeo.formats.raster import Raster from lmgeo.formats.raster import Raster
from lmgeo.formats.gridenvelope2d import GridEnvelope2D from lmgeo.formats.gridenvelope2d import GridEnvelope2D
from lmgeo.formats.const import constants as const from lmgeo.formats.const import constants as const
...@@ -25,11 +27,13 @@ class RioRaster(Raster, GridEnvelope2D): ...@@ -25,11 +27,13 @@ class RioRaster(Raster, GridEnvelope2D):
with an interface similar to the classes found in the formats folder of package lmgeo with an interface similar to the classes found in the formats folder of package lmgeo
''' '''
data = None data = None
nbands = 1 # default
datatype = 'i' datatype = 'i'
__numpy_type = np.int __numpy_type = np.int
__rows_per_strip = 128 # default for reading __rows_per_strip = 128 # default for reading
__number_of_strips = 1 __number_of_strips = 1
__currow = -1 __currow = -1
__crs = CRS.from_epsg(4326) # default
def __init__(self, filepath, *datatype): def __init__(self, filepath, *datatype):
# Initialise # Initialise
...@@ -41,13 +45,13 @@ class RioRaster(Raster, GridEnvelope2D): ...@@ -41,13 +45,13 @@ class RioRaster(Raster, GridEnvelope2D):
print('File path cannot be an empty string (method __init__).') print('File path cannot be an empty string (method __init__).')
self.name = os.path.basename(filepath); self.name = os.path.basename(filepath);
self.folder = os.path.dirname(filepath); self.folder = os.path.dirname(filepath);
self.datatype = datatype[0] self.datatype = datatype[0]
# overrides same method of Raster # overrides same method of Raster
def getWorldFileExt(self): def getWorldFileExt(self):
return 'tfw' return 'tfw'
def open(self, mode, ncols=1, nrows=1, xll=0, yll=0, cellsize=100, nodatavalue=-9999.0): def open(self, mode, ncols=1, nrows=1, nbands=1, xll=0, yll=0, cellsize=100, nodatavalue=-9999.0):
# Initialise # Initialise
fn = os.path.join(self.folder, self.name) fn = os.path.join(self.folder, self.name)
...@@ -86,6 +90,7 @@ class RioRaster(Raster, GridEnvelope2D): ...@@ -86,6 +90,7 @@ class RioRaster(Raster, GridEnvelope2D):
# Idea is to write 16 lines at once # Idea is to write 16 lines at once
self.ncols = ncols self.ncols = ncols
self.nrows = nrows self.nrows = nrows
self.nbands = nbands
self.xll = xll self.xll = xll
self.yll = yll self.yll = yll
self.dx = cellsize self.dx = cellsize
...@@ -94,14 +99,17 @@ class RioRaster(Raster, GridEnvelope2D): ...@@ -94,14 +99,17 @@ class RioRaster(Raster, GridEnvelope2D):
yul = self.yll + self.nrows * self.dy yul = self.yll + self.nrows * self.dy
self.__rows_per_strip = 16 # default for writing self.__rows_per_strip = 16 # default for writing
numblocks = (1 + self.ncols // 16) numblocks = (1 + self.ncols // 16)
self.data = np.empty((self.__rows_per_strip, ncols), dtype=dtype) if self.nbands == 1:
self.data = np.empty((self.__rows_per_strip, ncols), dtype=dtype)
else:
self.data = np.empty((self.nbands, self.__rows_per_strip, ncols), dtype=dtype)
self.data.fill(nodatavalue) self.data.fill(nodatavalue)
# Prepare a profile and then open the TIFF file for writing # Prepare a profile and then open the TIFF file for writing
# TODO: make this suitable for more than 1 band # TODO: make this suitable for more other coordinate reference systems
with rio.Env(): with rio.Env():
profile = {'driver': driver, 'dtype': dtype, 'nodata': nodatavalue, 'width': ncols, 'height': nrows, \ profile = {'driver': driver, 'dtype': dtype, 'nodata': nodatavalue, 'width': ncols, 'height': nrows, \
'count': 1, 'crs': CRS.from_epsg(4326), 'transform': Affine(cellsize, 0.0, xll, 0.0, -1*cellsize, yul), \ 'count': nbands, 'crs': self.__crs, 'transform': Affine(cellsize, 0.0, xll, 0.0, -1*cellsize, yul), \
'blockxsize': numblocks * 16, 'blockysize': self.__rows_per_strip , 'tiled': False, 'compress': 'lzw', \ 'blockxsize': numblocks * 16, 'blockysize': self.__rows_per_strip , 'tiled': False, 'compress': 'lzw', \
'interleave': 'band'} 'interleave': 'band'}
src = rio.open(fn, 'w', **profile) src = rio.open(fn, 'w', **profile)
...@@ -118,6 +126,10 @@ class RioRaster(Raster, GridEnvelope2D): ...@@ -118,6 +126,10 @@ class RioRaster(Raster, GridEnvelope2D):
self.ncols = src.meta["width"] self.ncols = src.meta["width"]
if "height" in src.meta: if "height" in src.meta:
self.nrows = src.meta["height"] self.nrows = src.meta["height"]
if "count" in src.meta:
self.nbands = src.meta["count"]
if "crs" in src.meta:
self.__crs = src.meta["crs"]
if "nodata" in src.meta: if "nodata" in src.meta:
self.nodatavalue = src.meta["nodata"] self.nodatavalue = src.meta["nodata"]
if ("transform" in src.meta) and (not src.meta["transform"] is None): if ("transform" in src.meta) and (not src.meta["transform"] is None):
...@@ -182,10 +194,18 @@ class RioRaster(Raster, GridEnvelope2D): ...@@ -182,10 +194,18 @@ class RioRaster(Raster, GridEnvelope2D):
if row_in_strip == 0: if row_in_strip == 0:
src = self.datafile src = self.datafile
nrows = min(self.__rows_per_strip, self.nrows - self.__currow) nrows = min(self.__rows_per_strip, self.nrows - self.__currow)
self.data = src.read(1, window=Window(0, self.__currow, self.ncols, nrows))
# We read only part of the image at the time
if self.nbands == 1:
self.data = src.read(1, window=Window(0, self.__currow, self.ncols, nrows))
else:
self.data = src.read(window=Window(0, self.__currow, self.ncols, nrows))
# Extract the next row # Extract the next row
result = self.data[row_in_strip, :] if self.nbands == 1:
result = self.data[row_in_strip, :]
else:
result = self.data[:, row_in_strip, :]
except StopIteration: except StopIteration:
raise StopIteration; raise StopIteration;
except Exception as e: except Exception as e:
...@@ -200,26 +220,47 @@ class RioRaster(Raster, GridEnvelope2D): ...@@ -200,26 +220,47 @@ class RioRaster(Raster, GridEnvelope2D):
curstrip = self.__currow // self.__rows_per_strip # zero-based! curstrip = self.__currow // self.__rows_per_strip # zero-based!
# Now let's see what we can do # Now let's see what we can do
A = (self.__currow != 0) and (row_in_strip == 0) # previous strip is ready for writing A = (self.__currow != 0) and (row_in_strip == 0) # previous strip is ready: can be written to disk
B = (self.__currow == self.nrows - 1) # last strip still has to be written B = (self.__currow == self.nrows - 1) # very last strip still has to be written
if A or B: if A:
if A: # Prepare to write the complete strip that is ready
ioffset = (-1 + curstrip) * self.__rows_per_strip ioffset = (-1 + curstrip) * self.__rows_per_strip
height = self.__rows_per_strip height = self.__rows_per_strip
else:
ioffset = curstrip * self.__rows_per_strip # Write the data to file and prepare for the next loop
height = self.nrows - (curstrip * self.__rows_per_strip) self.__writestrip(ioffset, height)
self.data[row_in_strip, :] = sequence_with_data
self.data = self.data[:height, :]
mywindow = Window(col_off=0, row_off=ioffset, width=self.ncols, height=height)
self.datafile.write(self.data, window=mywindow, indexes=1)
self.data.fill(self.nodatavalue) self.data.fill(self.nodatavalue)
# Always update the data if B:
if not B: self.data[row_in_strip, :] = sequence_with_data # Prepare the very last strip of the raster
ioffset = curstrip * self.__rows_per_strip
height = self.nrows - (curstrip * self.__rows_per_strip)
# Now we assign the last data and then we write
if self.nbands == 1:
self.data[row_in_strip, :] = sequence_with_data
self.data = self.data[0:height, :]
else:
self.data[:, row_in_strip, :] = sequence_with_data
self.data = self.data[:, 0:height, :]
self.__writestrip(ioffset, height)
else:
# We have not yet reached the end of the raster - update the data
if self.nbands == 1:
self.data[row_in_strip, :] = sequence_with_data
else:
self.data[:, row_in_strip, :] = sequence_with_data
def __writestrip(self, ioffset, height):
# Write the data to file and prepare for the next loop
mywindow = Window(col_off=0, row_off=ioffset, width=self.ncols, height=height)
if self.nbands == 1:
self.datafile.write(self.data, window=mywindow, indexes=1)
else:
if height == 1: self.data = np.reshape(self.data, (self.nbands, 1, self.ncols))
self.datafile.write(self.data, window=mywindow)
self.data.fill(self.nodatavalue)
def close(self): def close(self):
try: try:
if self.datafile: if self.datafile:
...@@ -234,4 +275,11 @@ class RioRaster(Raster, GridEnvelope2D): ...@@ -234,4 +275,11 @@ class RioRaster(Raster, GridEnvelope2D):
def reset(self): def reset(self):
self.__currow = -1 self.__currow = -1
@property
\ No newline at end of file def crs(self):
return self.__crs
@crs.setter
def crs(self, crs):
# TODO: differentiate dx and dy!
self.__crs = crs
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment