Skip to content

Commit

Permalink
record Neighbor
Browse files Browse the repository at this point in the history
  • Loading branch information
haifengl committed Dec 16, 2024
1 parent 87a7934 commit aa2bfb9
Show file tree
Hide file tree
Showing 15 changed files with 227 additions and 148 deletions.
157 changes: 130 additions & 27 deletions base/src/main/java/smile/graph/NearestNeighborGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,44 @@ public static NearestNeighborGraph of(double[][] data, int k) {
return of(data, MathEx::distance, k);
}

/**
* Returns the largest connected component of a nearest neighbor graph.
*
* @param digraph create a directed graph if true.
* @return the largest connected component.
*/
public NearestNeighborGraph largest(boolean digraph) {
AdjacencyList graph = graph(digraph);
int[][] cc = graph.bfcc();
if (cc.length == 1) {
return this;
} else {
int[] index = Arrays.stream(cc)
.max(Comparator.comparing(a -> a.length))
.orElseThrow(NoSuchElementException::new);
logger.info("{} connected components, largest one has {} samples.", cc.length, index.length);

int n = index.length;
int k = neighbors[0].length;

int[] reverseIndex = new int[neighbors.length];
for (int i = 0; i < n; i++) {
reverseIndex[index[i]] = i;
}

int[][] nearest = new int[n][k];
double[][] dist = new double[n][k];
for (int i = 0; i < n; i++) {
dist[i] = distances[index[i]];
int[] ni = neighbors[index[i]];
for (int j = 0; j < k; j++) {
nearest[i][j] = reverseIndex[ni[j]];
}
}
return new NearestNeighborGraph(nearest, dist, index);
}
}

private static class Neighbor implements Comparable<Neighbor> {
public int index;
public double distance;
Expand Down Expand Up @@ -98,6 +136,7 @@ public static <T> NearestNeighborGraph of(T[] data, Distance<T> distance, int k)
if (k < 2) {
throw new IllegalArgumentException("k must be greater than 1: " + k);
}

int n = data.length;
int[][] neighbors = new int[n][k];
double[][] distances = new double[n][k];
Expand Down Expand Up @@ -129,40 +168,104 @@ public static <T> NearestNeighborGraph of(T[] data, Distance<T> distance, int k)
}

/**
* Returns the largest connected component of a nearest neighbor graph.
* Creates an approximate nearest neighbor graph with Euclidean distance.
*
* @param digraph create a directed graph if true.
* @return the largest connected component.
* @param data the dataset.
* @param k k-nearest neighbor.
* @return approximate k-nearest neighbor graph.
*/
public NearestNeighborGraph largest(boolean digraph) {
AdjacencyList graph = graph(digraph);
int[][] cc = graph.bfcc();
if (cc.length == 1) {
return this;
} else {
int[] index = Arrays.stream(cc)
.max(Comparator.comparing(a -> a.length))
.orElseThrow(NoSuchElementException::new);
logger.info("{} connected components, largest one has {} samples.", cc.length, index.length);
/*public static NearestNeighborGraph descent(double[][] data, int k) {
return descent(data, MathEx::distance, k);
}*/

int n = index.length;
int k = neighbors[0].length;
/**
* Creates an approximate nearest neighbor graph with the NN-Descent algorithm.
*
* @param data the dataset.
* @param k k-nearest neighbor.
* @param distance the distance function.
* @param maxCandidates the maximum number of candidates in nearest neighbor search.
* @param maxIter the maximum number of iterations.
* @return approximate k-nearest neighbor graph.
*/
/*public static <T> NearestNeighborGraph descent(T[] data, Distance<T> distance, int k, int maxCandidates, int maxIter, double delta, double rho) {
if (k < 2) {
throw new IllegalArgumentException("k must be greater than 1: " + k);
}
int[] reverseIndex = new int[neighbors.length];
for (int i = 0; i < n; i++) {
reverseIndex[index[i]] = i;
int n = data.length;
int[][] neighbors = new int[n][k];
double[][] distances = new double[n][k];
boolean[][] isNew = new boolean[n][k];
for (int i = 0; i < n; i++) {
Arrays.fill(neighbors[i], -1);
Arrays.fill(distances[i], Double.POSITIVE_INFINITY);
}
for (int i = 0; i < n; i++) {
final float[] iRow = data.row(i);
for (final int index : Utils.rejectionSample(nNeighbors, data.rows(), random)) {
final float d = mMetric.distance(iRow, data.row(index));
currentGraph.push(i, d, index, true);
currentGraph.push(index, d, i, true);
}
}
int[][] nearest = new int[n][k];
double[][] dist = new double[n][k];
for (int i = 0; i < n; i++) {
dist[i] = distances[index[i]];
int[] ni = neighbors[index[i]];
for (int j = 0; j < k; j++) {
nearest[i][j] = reverseIndex[ni[j]];
if (rpTreeInit) {
for (final FlatTree tree : forest) {
for (final int[] leaf : tree.getIndices()) {
for (int i = 0; i < leaf.length; ++i) {
final float[] iRow = data.row(leaf[i]);
for (int j = i + 1; j < leaf.length; ++j) {
final float d = mMetric.distance(iRow, data.row(leaf[j]));
currentGraph.push(leaf[i], d, leaf[j], true);
currentGraph.push(leaf[j], d, leaf[i], true);
}
}
}
}
return new NearestNeighborGraph(nearest, dist, index);
}
}
boolean[] rejectStatus = new boolean[maxCandidates];
for (int iter = 0; iter < maxIter; iter++) {
logger.info("NearestNeighborDescent: {} / {}", (n + 1), maxIter);
final Heap candidateNeighbors = currentGraph.buildCandidates(n, k, maxCandidates);
int count = 0;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < maxCandidates; ++j) {
rejectStatus[j] = MathEx.random() < rho;
}
for (int j = 0; j < maxCandidates; ++j) {
final int p = candidateNeighbors.index(i, j);
if (p < 0) {
continue;
}
for (int l = 0; l <= j; l++) {
final int q = candidateNeighbors.index(i, l);
if (q < 0 || (rejectStatus[j] && rejectStatus[l]) || (!candidateNeighbors.isNew(i, j) && !candidateNeighbors.isNew(i, l))) {
continue;
}
final float d = mMetric.distance(data.row(p), data.row(q));
if (currentGraph.push(p, d, q, true)) {
++count;
}
if (currentGraph.push(q, d, p, true)) {
++count;
}
}
}
}
if (count <= delta * k * n) {
break;
}
}
return currentGraph.deheapSort();
return new NearestNeighborGraph(neighbors, distances);
}*/
}
4 changes: 2 additions & 2 deletions base/src/main/java/smile/neighbor/MPLSH.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ public void fit(RNNSearch<double[], double[]> range, double[][] samples, double
training[i] = new MultiProbeSample(samples[i], new LinkedList<>());
ArrayList<Neighbor<double[], double[]>> neighbors = new ArrayList<>();
range.search(samples[i], radius, neighbors);
for (Neighbor<double[], double[]> n : neighbors) {
training[i].neighbors.add(keys.get(n.index));
for (var neighbor : neighbors) {
training[i].neighbors.add(keys.get(neighbor.index()));
}
}

Expand Down
46 changes: 10 additions & 36 deletions base/src/main/java/smile/neighbor/Neighbor.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,16 @@
* object in the dataset, which is often useful, and the distance between
* the query key to the object key.
*
* @param key the key of neighbor.
* @param value the value of neighbor.
* @param index the index of neighbor object in the dataset.
* @param distance the distance between the query and the neighbor.
* @param <K> the type of keys.
* @param <V> the type of associated objects.
*
* @author Haifeng Li
*/
public class Neighbor<K, V> implements Comparable<Neighbor<K,V>> {
/**
* The key of neighbor.
*/
public final K key;
/**
* The data object of neighbor. It may be same as the key object.
*/
public final V value;
/**
* The index of neighbor object in the dataset.
*/
public final int index;
/**
* The distance between the query and the neighbor.
*/
public final double distance;

/**
* Constructor.
* @param key the key of neighbor.
* @param object the value of neighbor.
* @param index the index of neighbor object in the dataset.
* @param distance the distance between the query and the neighbor.
*/
public Neighbor(K key, V object, int index, double distance) {
this.key = key;
this.value = object;
this.index = index;
this.distance = distance;
}

public record Neighbor<K, V>(K key, V value, int index, double distance) implements Comparable<Neighbor<K,V>> {
@Override
public int compareTo(Neighbor<K,V> o) {
int d = Double.compare(distance, o.distance);
Expand All @@ -74,15 +47,16 @@ public int compareTo(Neighbor<K,V> o) {

@Override
public String toString() {
return String.format("%s(%d):%s", key, index, Strings.format(distance));
return String.format("Neighbor(%s[%d]: %s)", key, index, Strings.format(distance));
}

/**
* Creates a neighbor object, of which key and object are the same.
* @param key the query key.
* @param index the index of object.
*
* @param key the query key.
* @param index the index of object.
* @param distance the distance between query key and neighbor.
* @param <T> the data type of key and object.
* @param <T> the data type of key and object.
* @return the neighbor object.
*/
public static <T> Neighbor<T, T> of(T key, int index, double distance) {
Expand Down
2 changes: 2 additions & 0 deletions base/src/main/java/smile/neighbor/NeighborBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

/**
* The mutable object as a template to create a Neighbor object.
* This helps reduce the creation of a lot of temporary objects
* as we can update this object's values in the heap.
*
* @param <K> the type of keys.
* @param <V> the type of associated objects.
Expand Down
4 changes: 2 additions & 2 deletions base/src/test/java/smile/neighbor/BKTreeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ public void testRange() {
String[] s1 = new String[n1.size()];
String[] s2 = new String[n2.size()];
for (int j = 0; j < s1.length; j++) {
s1[j] = n1.get(j).value;
s2[j] = n2.get(j).value;
s1[j] = n1.get(j).value();
s2[j] = n2.get(j).value();
}
Arrays.sort(s1);
Arrays.sort(s2);
Expand Down
24 changes: 12 additions & 12 deletions base/src/test/java/smile/neighbor/CoverTreeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ public void testNearest() {
for (double[] datum : data) {
Neighbor n1 = coverTree.nearest(datum);
Neighbor n2 = naive.nearest(datum);
assertEquals(n1.index, n2.index);
assertEquals(n1.value, n2.value);
assertEquals(n1.distance, n2.distance, 1E-7);
assertEquals(n1.index(), n2.index());
assertEquals(n1.value(), n2.value());
assertEquals(n1.distance(), n2.distance(), 1E-7);
}
}

Expand All @@ -87,9 +87,9 @@ public void testKnn() {
Neighbor[] n2 = naive.search(datum, 10);
assertEquals(n1.length, n2.length);
for (int j = 0; j < n1.length; j++) {
assertEquals(n1[j].index, n2[j].index);
assertEquals(n1[j].value, n2[j].value);
assertEquals(n1[j].distance, n2[j].distance, 1E-7);
assertEquals(n1[j].index(), n2[j].index());
assertEquals(n1[j].value(), n2[j].value());
assertEquals(n1[j].distance(), n2[j].distance(), 1E-7);
}
}
}
Expand All @@ -107,9 +107,9 @@ public void testKnn1() {

Neighbor[] n1 = coverTree.search(data[1], 1);
assertEquals(1, n1.length);
assertEquals(0, n1[0].index);
assertEquals(data[0], n1[0].value);
assertEquals(MathEx.distance(data[0], data[1]), n1[0].distance, 1E-7);
assertEquals(0, n1[0].index());
assertEquals(data[0], n1[0].value());
assertEquals(MathEx.distance(data[0], data[1]), n1[0].distance(), 1E-7);
}

@Test
Expand All @@ -129,9 +129,9 @@ public void testRange() {
Collections.sort(n2);
assertEquals(n1.size(), n2.size());
for (int j = 0; j < n1.size(); j++) {
assertEquals(n1.get(j).index, n2.get(j).index);
assertEquals(n1.get(j).value, n2.get(j).value);
assertEquals(n1.get(j).distance, n2.get(j).distance, 1E-7);
assertEquals(n1.get(j).index(), n2.get(j).index());
assertEquals(n1.get(j).value(), n2.get(j).value());
assertEquals(n1.get(j).distance(), n2.get(j).distance(), 1E-7);
}
n1.clear();
n2.clear();
Expand Down
Loading

0 comments on commit aa2bfb9

Please sign in to comment.