Commit 2d52a2b2 authored by Daniel Scheffler's avatar Daniel Scheffler
Browse files

Moved imports of scikit-learn to function/class level to avoid static TLS...


Moved imports of scikit-learn to function/class level to avoid static TLS ImportError. Updated version info.
Signed-off-by: Daniel Scheffler's avatarDaniel Scheffler <danschef@gfz-potsdam.de>
parent 29d31b8d
Pipeline #12860 passed with stages
in 2 minutes and 50 seconds
......@@ -25,8 +25,6 @@
import numpy as np
from typing import Union, List # noqa F401 # flake8 issue
from sklearn.neighbors import KNeighborsClassifier as _KNeighborsClassifier
from sklearn.neighbors import NearestCentroid as _NearestCentroid
from geoarray import GeoArray
from .._baseclasses import _ImageClassifier, _kNN_ImageClassifier
......@@ -43,6 +41,8 @@ class MinimumDistance_Classifier(_ImageClassifier):
def __init__(self, train_spectra, train_labels, CPUs=1, **kwargs):
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None], dict) -> None
from sklearn.neighbors import NearestCentroid as _NearestCentroid # avoids static TLS errors here
super(MinimumDistance_Classifier, self).__init__(train_spectra, train_labels, CPUs=CPUs)
self.clf_name = 'minimum distance (nearest centroid)'
......@@ -130,6 +130,8 @@ class kNN_MinimumDistance_Classifier(MinimumDistance_Classifier, _kNN_ImageClass
class kNN_Classifier(_ImageClassifier):
def __init__(self, train_spectra, train_labels, CPUs=1, **kwargs):
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None], dict) -> None
from sklearn.neighbors import KNeighborsClassifier as _KNeighborsClassifier # avoids static TLS errors here
super(kNN_Classifier, self).__init__(train_spectra, train_labels, CPUs=CPUs)
self.clf_name = 'k-nearest neighbour (kNN)'
......
......@@ -25,7 +25,6 @@
import numpy as np
from typing import Union, List # noqa F401 # flake8 issue
from sklearn.ensemble import RandomForestClassifier as _RandomForestClassifier
from geoarray import GeoArray
from .._baseclasses import _ImageClassifier
......@@ -36,6 +35,7 @@ class RF_Classifier(_ImageClassifier):
"""Random forest classifier."""
def __init__(self, train_spectra, train_labels, CPUs=1, **kwargs):
# type: (np.ndarray, Union[np.ndarray, List[int]], Union[int, None], dict) -> None
from sklearn.ensemble import RandomForestClassifier as _RandomForestClassifier # avoids static TLS errors here
# if CPUs is None or CPUs > 1:
# CPUs = 1 # The NearestCentroid seems to parallelize automatically. So using multiprocessing is slower.
......
......@@ -25,12 +25,13 @@
import numpy as np
from typing import Union, Tuple # noqa F401 # flake8 issue
from sklearn.preprocessing import MaxAbsScaler
from geoarray import GeoArray
def normalize_endmembers_image(endmembers, image):
# type: (np.ndarray, np.ndarray) -> Tuple[np.ndarray, np.ndarray]
from sklearn.preprocessing import MaxAbsScaler # avoids static TLS errors here
em = endmembers.astype(np.float)
im = image.astype(np.float)
......
......@@ -22,5 +22,5 @@
# with this program. If not, see <http://www.gnu.org/licenses/>.
__version__ = '0.2.4'
__versionalias__ = '20200915.01'
__version__ = '0.2.5'
__versionalias__ = '20200924.01'
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment