Skip to content

Commit

Permalink
Simplified classification code
Browse files Browse the repository at this point in the history
  • Loading branch information
davecom committed May 6, 2024
1 parent 12a9e48 commit 1209dfb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python package
name: Python Unit Tests, Lint, and Type Checks

on:
push:
Expand Down Expand Up @@ -39,3 +39,5 @@ jobs:
- name: Run all unit tests in the /tests directory
run: |
python -m unittest discover -s tests
- name: Check Type Hints with Pyright
uses: jakebailey/[email protected]
11 changes: 3 additions & 8 deletions KNN/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
import csv
from typing import Protocol, Self
from collections import Counter
import numpy as np


Expand Down Expand Up @@ -44,19 +45,13 @@ def _read_csv(self, file_path: str, has_header: bool) -> None:

# Find the k nearest neighbors of a given data point based on the distance method
def nearest(self, k: int, data_point: DP) -> list[DP]:
return sorted(self.data_points, key=lambda other: data_point.distance(other))[:k]
return sorted(self.data_points, key=data_point.distance)[:k]

# Classify a data point based on the k nearest neighbors
# Choose the kind with the most neighbors and return it
def classify(self, k: int, data_point: DP) -> str:
neighbors = self.nearest(k, data_point)
kinds = {}
for neighbor in neighbors:
if neighbor.kind in kinds:
kinds[neighbor.kind] += 1
else:
kinds[neighbor.kind] = 1
return max(kinds, key=kinds.get) # type: ignore
return Counter(neighbor.kind for neighbor in neighbors).most_common(1)[0][0]

# Predict a property of a data point based on the k nearest neighbors
# Find the average of that property from the neighbors and return it
Expand Down

0 comments on commit 1209dfb

Please sign in to comment.