Commit 0bafa7a4 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Added random forest classifier to classification algorithms + tests.

parent b4593aad
Pipeline #3750 passed with stage
in 20 minutes and 12 seconds
......@@ -8,6 +8,7 @@ from typing import Union, List, Tuple # noqa F401 # flake8 issue
from multiprocessing import Pool
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier, NearestCentroid
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MaxAbsScaler
from geoarray import GeoArray
from py_tools_ds.numeric.array import get_array_tilebounds
......@@ -238,6 +239,31 @@ class SID_Classifier(_ImageClassifier):
return np.sum(p * np.log(p / q) + q * np.log(q / p), axis=axis)
class RF_Classifier(_ImageClassifier):
"""Random forest classifier."""
def __init__(self, train_spectra, train_labels, CPUs=1, n_estimators=100, max_depth=2, random_state=0, **kw):
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None], int, int, int, dict) -> None
# if CPUs is None or CPUs > 1:
# CPUs = 1 # The NearestCentroid seems to parallelize automatically. So using multiprocessing is slower.
super(RF_Classifier, self).__init__(train_spectra, train_labels, CPUs=CPUs)
self.clf_name = 'random forest'
self.clf = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, random_state=random_state,
n_jobs=1, **kw), train_labels)
def _predict(self, tilepos):
assert global_shared_im2classify is not None
(rS, rE), (cS, cE) = tilepos
tileimdata = global_shared_im2classify[rS: rE + 1, cS: cE + 1, :]
spectra = tileimdata.reshape((tileimdata.shape[0] * tileimdata.shape[1], tileimdata.shape[2]))
return tilepos, self.clf.predict(spectra).reshape(*tileimdata.shape[:2])
def classify_image(image, train_spectra, train_labels, classif_alg,
kNN_n_neighbors=10, in_nodataVal=None, cmap_nodataVal=None, tiledims=(1000, 1000), CPUs=None):
# type: (Union[np.ndarray, GeoArray], np.ndarray, Union[np.ndarray, List[int]], str, int, ...) -> GeoArray
......@@ -252,6 +278,7 @@ def classify_image(image, train_spectra, train_labels, classif_alg,
'kNN': k-nearest-neighbour
'SAM': spectral angle mapping
'SID': spectral information divergence
'RF': random forest
:param kNN_n_neighbors: The number of neighbors to be considered in case 'classif_alg' is set to
'kNN'. Otherwise, this parameter is ignored.
:param in_nodataVal:
......@@ -281,8 +308,14 @@ def classify_image(image, train_spectra, train_labels, classif_alg,
elif classif_alg == 'RF':
clf = RF_Classifier(
raise NotImplementedError("Currently only the methods 'kNN', 'MinDist', 'SAM' and 'SID' are implemented.")
raise NotImplementedError("Currently only the methods 'kNN', 'MinDist', 'SAM', 'SID' and 'RF' are implemented.")
cmap = clf.classify(image, in_nodataVal=in_nodataVal, cmap_nodataVal=cmap_nodataVal, tiledims=tiledims)
......@@ -17,10 +17,8 @@ import numpy as np
from geoarray import GeoArray
from gms_preprocessing import set_config
from gms_preprocessing.algorithms.classification import MinimumDistance_Classifier
from gms_preprocessing.algorithms.classification import kNN_Classifier
from gms_preprocessing.algorithms.classification import SAM_Classifier
from gms_preprocessing.algorithms.classification import SID_Classifier
from gms_preprocessing.algorithms.classification import \
MinimumDistance_Classifier, kNN_Classifier, SAM_Classifier, SID_Classifier, RF_Classifier
from . import db_host
......@@ -98,3 +96,18 @@ class Test_SID_Classifier(unittest.TestCase):
self.assertEqual(cmap_mp.shape, (1010, 1010))
self.assertTrue(np.array_equal(cmap_sp, cmap_mp))
class Test_RF_Classifier(unittest.TestCase):
def test_classify(self):
RFC = RF_Classifier(cluster_centers, cluster_labels, CPUs=1)
cmap_sp = RFC.classify(test_gA, in_nodataVal=-9999, tiledims=(400, 200))
self.assertIsInstance(cmap_sp, np.ndarray)
self.assertEqual(cmap_sp.shape, (1010, 1010))
RFC = RF_Classifier(cluster_centers, cluster_labels, CPUs=None)
cmap_mp = RFC.classify(test_gA, in_nodataVal=-9999, tiledims=(400, 200))
self.assertIsInstance(cmap_mp, np.ndarray)
self.assertEqual(cmap_mp.shape, (1010, 1010))
self.assertTrue(np.array_equal(cmap_sp, cmap_mp))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment