.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/ml/plot_k_neighbors_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_ml_plot_k_neighbors_classification.py>`
        to download the full example code or to run this example in your
        browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_ml_plot_k_neighbors_classification.py:

K-nearest neighbors classification
==================================

Shows the usage of the k-nearest neighbors classifier.

.. GENERATED FROM PYTHON SOURCE LINES 7-11

.. code-block:: Python

    # Author: Pablo Marcos Manchón
    # License: MIT

.. GENERATED FROM PYTHON SOURCE LINES 12-22

In this example we show how to use the k-nearest neighbors classifier in its
functional version, which is an extension of the multivariate one that uses
functional metrics.

First, we fetch a functional dataset, the Berkeley Growth Study. This dataset
contains the heights of several boys and girls, measured until 18 years of
age. We will try to predict sex from the growth curves.

The following figure shows the growth curves grouped by sex.

.. GENERATED FROM PYTHON SOURCE LINES 23-37

.. code-block:: Python

    import matplotlib.pyplot as plt

    from skfda.datasets import fetch_growth

    X_df, y_df = fetch_growth(return_X_y=True, as_frame=True)
    X = X_df.iloc[:, 0].array
    target = y_df.array
    y = target.codes

    # Plot samples grouped by sex
    X.plot(group=target.codes, group_names=target.categories)
    plt.show()

.. image-sg:: /auto_examples/ml/images/sphx_glr_plot_k_neighbors_classification_001.png
   :alt: Berkeley Growth Study
   :srcset: /auto_examples/ml/images/sphx_glr_plot_k_neighbors_classification_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 46-48

The class labels are stored in an array. Zeros represent male samples while
ones represent female samples.

.. GENERATED FROM PYTHON SOURCE LINES 49-52

.. code-block:: Python

    print(y)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

.. GENERATED FROM PYTHON SOURCE LINES 53-59

We can split the dataset using the sklearn function
:func:`~sklearn.model_selection.train_test_split`.

The function returns two :class:`~skfda.representation.grid.FDataGrid`
objects, ``X_train`` and ``X_test``, with the corresponding partitions, and
arrays with their class labels.

.. GENERATED FROM PYTHON SOURCE LINES 60-72

.. code-block:: Python

    from sklearn.model_selection import GridSearchCV, train_test_split

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.25,
        stratify=y,
        random_state=0,
    )

.. GENERATED FROM PYTHON SOURCE LINES 73-80

We will fit the classifier
:class:`~skfda.ml.classification.KNeighborsClassifier` with the training
partition. This classifier works exactly like the sklearn multivariate
classifier :class:`~sklearn.neighbors.KNeighborsClassifier`, but its input is
a :class:`~skfda.representation.grid.FDataGrid` with functional observations
instead of an array with multivariate data.

.. GENERATED FROM PYTHON SOURCE LINES 81-87

.. code-block:: Python

    from skfda.ml.classification import KNeighborsClassifier

    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(X_train, y_train)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    KNeighborsClassifier()


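As a rough illustration of what changes with respect to the multivariate
classifier, the following plain-Python sketch replaces the Euclidean distance
between vectors with an approximate :math:`\mathbb{L}^2` distance between
curves and then takes a majority vote among the nearest training curves. It
is not how scikit-fda implements the classifier. It assumes that all curves
are sampled on a common grid and that the integral is approximated with the
trapezoidal rule. The helper names ``l2_distance`` and ``knn_predict`` and the
toy data are hypothetical.

```python
from collections import Counter
from math import sqrt


def l2_distance(f, g, grid):
    """Approximate the L2 distance between two curves sampled on `grid`.

    The integral of the squared difference is computed with the
    trapezoidal rule (an assumption of this sketch).
    """
    sq = [(a - b) ** 2 for a, b in zip(f, g)]
    integral = sum(
        0.5 * (sq[i] + sq[i + 1]) * (grid[i + 1] - grid[i])
        for i in range(len(grid) - 1)
    )
    return sqrt(integral)


def knn_predict(train_curves, train_labels, curve, grid, k=5):
    """Return the majority label among the k training curves nearest to `curve`."""
    # Sort training indices by their distance to the new curve
    order = sorted(
        range(len(train_curves)),
        key=lambda i: l2_distance(train_curves[i], curve, grid),
    )
    votes = Counter(train_labels[i] for i in order[:k])
    return votes.most_common(1)[0][0]


# Toy data (hypothetical, not the Berkeley Growth Study): straight lines
# with different slopes, sampled at five points of [0, 1].
grid = [0.0, 0.25, 0.5, 0.75, 1.0]
train_curves = [[s * t for t in grid] for s in (1.0, 1.1, 1.2, 3.0, 3.1, 3.2)]
train_labels = [0, 0, 0, 1, 1, 1]

new_curve = [2.9 * t for t in grid]
predicted = knn_predict(train_curves, train_labels, new_curve, grid, k=3)
print(predicted)
```

In this toy setting the three curves nearest to ``new_curve`` are the steep
ones, so the vote goes to their class. Swapping ``l2_distance`` for another
distance between curves changes the neighborhoods without touching the
voting step.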
.. GENERATED FROM PYTHON SOURCE LINES 88-99

Once it is fitted, we can predict labels for the test samples.

To predict the label of a test sample, the classifier finds its k nearest
neighbors in the training set and assigns the class shared by most of them.
In this case, we have set the number of neighbors to 5 (:math:`k=5`).

By default, it uses the :math:`\mathbb{L}^2` distance between functions to
determine the neighborhood of a sample. However, it can be used with any of
the functional metrics described in :doc:`/modules/misc/metrics`.

.. GENERATED FROM PYTHON SOURCE LINES 100-104

.. code-block:: Python

    pred = knn.predict(X_test)
    print(pred)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [0 0 1 0 1 1 1 0 0 0 0 1 1 0 0 0 0 1 1 1 1 1 1 1]

.. GENERATED FROM PYTHON SOURCE LINES 105-108

The :func:`~skfda.ml.classification.KNeighborsClassifier.score` method
allows us to calculate the mean accuracy on the test data. In this case we
obtained an accuracy of around 96%.

.. GENERATED FROM PYTHON SOURCE LINES 109-113

.. code-block:: Python

    score = knn.score(X_test, y_test)
    print(score)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.9583333333333334

.. GENERATED FROM PYTHON SOURCE LINES 114-118

We can also estimate the probability of membership to each class using
:func:`~skfda.ml.classification.KNeighborsClassifier.predict_proba`, which
returns an array with the probabilities of the classes, in lexicographic
order, for each test sample.

.. GENERATED FROM PYTHON SOURCE LINES 119-124

.. code-block:: Python

    probs = knn.predict_proba(X_test[:5])  # Predict first 5 samples
    print(probs)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [[1.  0. ]
     [0.6 0.4]
     [0.  1. ]
     [1.  0. ]
     [0.  1. ]]

.. GENERATED FROM PYTHON SOURCE LINES 125-130

We can use the sklearn :class:`~sklearn.model_selection.GridSearchCV` to
perform a grid search that selects the best hyperparameters using
cross-validation.

In this case, we will vary the number of neighbors between 1 and 17.

.. GENERATED FROM PYTHON SOURCE LINES 131-147

.. code-block:: Python

    # Only odd numbers, to prevent ties
    param_grid = {"n_neighbors": range(1, 18, 2)}

    knn = KNeighborsClassifier()

    # Perform grid search with cross-validation
    gscv = GridSearchCV(knn, param_grid, cv=5)
    gscv.fit(X_train, y_train)

    print("Best params:", gscv.best_params_)
    print("Best cross-validation score:", gscv.best_score_)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Best params: {'n_neighbors': 11}
    Best cross-validation score: 0.9571428571428573

.. GENERATED FROM PYTHON SOURCE LINES 148-150

We have obtained the greatest mean accuracy using 11 neighbors. The following
figure shows the score depending on the number of neighbors.

.. GENERATED FROM PYTHON SOURCE LINES 151-161

.. code-block:: Python

    fig, ax = plt.subplots()
    ax.bar(param_grid["n_neighbors"], gscv.cv_results_["mean_test_score"])
    ax.set_xticks(param_grid["n_neighbors"])
    # The number of neighbors is on the x axis and the score on the y axis
    ax.set_xlabel("Number of Neighbors")
    ax.set_ylabel("Cross-validation score")
    ax.set_ylim((0.9, 1))
    plt.show()

.. image-sg:: /auto_examples/ml/images/sphx_glr_plot_k_neighbors_classification_002.png
   :alt: plot k neighbors classification
   :srcset: /auto_examples/ml/images/sphx_glr_plot_k_neighbors_classification_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 162-168

By default, after performing the cross-validation, the classifier with the
selected parameters is refitted on the whole training data provided in the
call to :func:`~skfda.ml.classification.KNeighborsClassifier.fit`.
Therefore, to check the accuracy of the classifier for the number of
neighbors selected (11), we can simply call the
:func:`~sklearn.model_selection.GridSearchCV.score` method.

.. GENERATED FROM PYTHON SOURCE LINES 169-173

.. code-block:: Python

    score = gscv.score(X_test, y_test)
    print(score)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    1.0

.. GENERATED FROM PYTHON SOURCE LINES 174-176

This classifier can also be used with multivariate functional data, such as
surfaces or curves in :math:`\mathbb{R}^N`, provided that the metric supports
it.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.252 seconds)

.. _sphx_glr_download_auto_examples_ml_plot_k_neighbors_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/GAA-UAM/scikit-fda/develop?filepath=examples/ml/plot_k_neighbors_classification.py
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_k_neighbors_classification.ipynb <plot_k_neighbors_classification.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_k_neighbors_classification.py <plot_k_neighbors_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_k_neighbors_classification.zip <plot_k_neighbors_classification.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_