Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Aug 13, 2023
1 parent 5d2bfa6 commit eba1ada
Show file tree
Hide file tree
Showing 2 changed files with 326 additions and 288 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.compress.utils;

import java.util.Arrays;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Base class for a chained hash map that counts how often each key has been seen and assigns every distinct key an
 * id equal to the number of distinct keys that were present when it was first inserted.
 *
 * While the map holds fewer than {@link #shortCutSize} entries, lookups scan every bucket instead of hashing the key.
 *
 * Subclasses provide the bucket array allocation, the entry allocation and the hash function.
 *
 * @param <T> The key type
 */
public abstract class ACountHashMap<T> {
	protected static final Log LOG = LogFactory.getLog(ACountHashMap.class.getName());
	/** Growth factor applied to the bucket array on resize. */
	protected static final int RESIZE_FACTOR = 2;
	/** Fraction of the bucket array that may be used before a resize is triggered. */
	protected static final float LOAD_FACTOR = 0.80f;
	/** Below this number of entries all buckets are scanned linearly instead of hashing the key. */
	protected static final int shortCutSize = 10;

	/** Number of distinct keys currently stored. */
	protected int size;
	/** Bucket array, each cell is the head of a singly linked chain of entries. */
	private ACount<T>[] data;

	/**
	 * Create an empty map backed by the given bucket array.
	 *
	 * @param data The bucket array to use
	 */
	public ACountHashMap(ACount<T>[] data) {
		this.data = data;
		this.size = 0;
	}

	/**
	 * Create a map backed by the given bucket array that already contains entries.
	 *
	 * @param data The bucket array to use
	 * @param size The number of entries already contained in data
	 */
	public ACountHashMap(ACount<T>[] data, int size) {
		this.data = data;
		this.size = size;
	}

	/**
	 * Get the number of distinct keys in the map.
	 *
	 * @return The number of distinct keys
	 */
	public int size() {
		return size;
	}

	/**
	 * Increment the count of the given key by one, inserting the key if it is not present.
	 *
	 * @param key The key to increment
	 * @return The id of the key
	 */
	public final int increment(T key) {
		return increment(key, 1);
	}

	/**
	 * Increment the count of the given key, inserting the key if it is not present.
	 *
	 * NOTE(review): when the key is new the entry is produced by {@link #create(Object, int)} and the count argument
	 * is not applied to it; confirm that the initial count set by the subclass is the intended one.
	 *
	 * @param key   The key to increment
	 * @param count The amount to add to the count of an existing key
	 * @return The id of the key
	 */
	public final int increment(final T key, final int count) {
		return size < shortCutSize ? incrementBypassHash(key, count) : incrementNormal(key, count);
	}

	/** Increment by scanning all buckets, only hashing the key if it has to be inserted. */
	private int incrementBypassHash(final T key, final int count) {
		final ACount<T> hit = findInAnyBucket(key);
		if(hit != null) {
			hit.count += count;
			return hit.id;
		}
		return addNewDCounts(bucketOf(key), key);
	}

	/** Increment by looking only at the bucket the key hashes to. */
	private int incrementNormal(final T key, final int count) {
		final int ix = bucketOf(key);
		final ACount<T> hit = findInChain(data[ix], key);
		if(hit != null) {
			hit.count += count;
			return hit.id;
		}
		return addNewDCounts(ix, key);
	}

	/** Insert a new entry at the head of bucket ix and resize if the load factor is exceeded. */
	private int addNewDCounts(final int ix, final T key) {
		final ACount<T> oldHead = data[ix];
		final ACount<T> fresh = create(key, size);
		fresh.next = oldHead;
		data[ix] = fresh;
		final int id = size++;
		if(size >= LOAD_FACTOR * data.length && size > shortCutSize)
			resize();
		return id;
	}

	/**
	 * Get the count of the given key.
	 *
	 * @param key The key to look up
	 * @return The count of the key
	 * @throws NullPointerException if the key is not contained in the map
	 */
	public int get(T key) {
		return getC(key).count;
	}

	/**
	 * Get the id of the given key.
	 *
	 * @param key The key to look up
	 * @return The id of the key
	 * @throws NullPointerException if the key is not contained in the map
	 */
	public int getId(T key) {
		return getC(key).id;
	}

	/**
	 * Get the entry of the given key.
	 *
	 * @param key The key to look up
	 * @return The entry, or null if the key is not contained in the map
	 */
	public ACount<T> getC(T key) {
		return size < shortCutSize ? getCByPassHash(key) : getCNormal(key);
	}

	private ACount<T> getCByPassHash(T key) {
		return findInAnyBucket(key);
	}

	private ACount<T> getCNormal(T key) {
		// returns null on a miss, consistent with the short cut path.
		return findInChain(data[bucketOf(key)], key);
	}

	/**
	 * Get the count of the given key, or a default if the key is not contained.
	 *
	 * @param key The key to look up
	 * @param def The value to return if the key is not contained
	 * @return The count of the key or def
	 */
	public int getOrDefault(T key, int def) {
		return size < shortCutSize ? getOrDefaultShortCut(key, def) : getOrDefaultHash(key, def);
	}

	private int getOrDefaultShortCut(T key, int def) {
		final ACount<T> hit = findInAnyBucket(key);
		return hit == null ? def : hit.count;
	}

	private int getOrDefaultHash(T key, int def) {
		final ACount<T> hit = findInChain(data[bucketOf(key)], key);
		return hit == null ? def : hit.count;
	}

	/**
	 * Collect all entries into an array in bucket iteration order.
	 *
	 * The returned entries are the live entries of the map and keep their next pointers.
	 *
	 * @return An array of length size() containing all entries
	 */
	public ACount<T>[] extractValues() {
		final ACount<T>[] ret = create(size);
		int i = 0;
		for(ACount<T> e : data) {
			// no cleanup aka potentially linked lists in cells.
			for(; e != null; e = e.next)
				ret[i++] = e;
		}
		return ret;
	}

	/** Overwrite the count of every entry with a running index in bucket iteration order. */
	public void replaceWithUIDs() {
		int i = 0;
		for(ACount<T> e : data)
			for(; e != null; e = e.next)
				e.count = i++;
	}

	/**
	 * Overwrite the count of every entry whose key is not equal to 0.0 with a running index in bucket iteration
	 * order. An entry with key 0.0 keeps its count.
	 */
	public void replaceWithUIDsNoZero() {
		int i = 0;
		for(ACount<T> e : data)
			for(; e != null; e = e.next)
				if(!isZeroKey(e.key()))
					e.count = i++;
	}

	/**
	 * Return the counts of all entries in bucket iteration order and overwrite each count with its position in the
	 * returned array.
	 *
	 * @return The counts, of length size()
	 */
	public int[] getUnorderedCountsAndReplaceWithUIDs() {
		final int[] counts = new int[size];
		int i = 0;
		for(ACount<T> e : data)
			for(; e != null; e = e.next) {
				counts[i] = e.count;
				e.count = i++;
			}
		return counts;
	}

	/**
	 * Return the counts of all entries whose key is not equal to 0.0 in bucket iteration order and overwrite each of
	 * those counts with its position in the returned array. An entry with key 0.0 is skipped and keeps its count.
	 *
	 * @return The counts, in an array of length size(); if a 0.0 key was skipped the trailing cell is left as 0
	 */
	public int[] getUnorderedCountsAndReplaceWithUIDsWithout0() {
		final int[] counts = new int[size];
		int i = 0;
		for(ACount<T> e : data)
			for(; e != null; e = e.next)
				if(!isZeroKey(e.key())) {
					counts[i] = e.count;
					e.count = i++;
				}
		return counts;
	}

	/**
	 * Get the key with the highest count.
	 *
	 * @return The key with the highest count above zero, or null if there is none
	 */
	public T getMostFrequent() {
		T best = null;
		int bestCount = 0;
		for(ACount<T> e : data)
			for(; e != null; e = e.next)
				if(e.count > bestCount) {
					bestCount = e.count;
					best = e.key();
				}
		return best;
	}

	/** Grow the bucket array by RESIZE_FACTOR unless that would overflow an int. */
	private void resize() {
		// check for integer overflow on resize
		if(data.length > Integer.MAX_VALUE / RESIZE_FACTOR)
			return;
		resize(Math.max(data.length * RESIZE_FACTOR, shortCutSize));
	}

	/** Replace the bucket array with one of the given capacity and rehash all entries into it. */
	private void resize(final int newCapacity) {
		final ACount<T>[] old = data;
		data = create(newCapacity);
		// reset the entry counter (the field), appendValue counts every re-inserted entry again.
		size = 0;
		for(ACount<T> e : old)
			appendValue(e);
	}

	/**
	 * Re-insert a chain of entries into the current bucket array, tail first.
	 *
	 * NOTE(review): the bucket index uses ent.hashIndex() while lookups use hash(key); these are assumed to agree.
	 * NOTE(review): the count ordered insert never places an entry in front of the bucket head, so chains are only
	 * ordered after their first element; use sortBuckets() if a full ordering is required.
	 */
	private void appendValue(ACount<T> ent) {
		if(ent == null)
			return;
		// take the tail recursively first, then detach this entry from it.
		appendValue(ent.next);
		ent.next = null;
		final int ix = ent.hashIndex() % data.length;
		final ACount<T> head = data[ix];
		if(head == null)
			data[ix] = ent;
		else {
			// walk past entries with a higher count and link in behind them.
			ACount<T> prev = head;
			ACount<T> cur = head.next;
			while(cur != null && cur.count > ent.count) {
				prev = cur;
				cur = cur.next;
			}
			prev.next = ent;
			ent.next = cur;
		}
		size++;
	}

	/** Sort the chain of every non-empty bucket via ACount.sort() if the map is above the short cut size. */
	public void sortBuckets() {
		if(size > shortCutSize)
			for(int i = 0; i < data.length; i++)
				if(data[i] != null)
					data[i] = data[i].sort();
	}

	/**
	 * Remove all entries from the map.
	 *
	 * If the current bucket array is more than twice as large as Util.getPow2(size) a new array of that length is
	 * allocated, otherwise the existing array is cleared in place.
	 *
	 * @param size The expected number of entries after the reset
	 */
	public void reset(int size) {
		final int p2 = Util.getPow2(size);
		if(data.length > 2 * p2)
			data = create(p2);
		else
			Arrays.fill(data, null);
		// the parameter shadows the field, the field is the one that has to be cleared.
		this.size = 0;
	}

	/** Find the entry with the given key in the chain starting at head, or null if there is none. */
	private ACount<T> findInChain(final ACount<T> head, final T key) {
		for(ACount<T> e = head; e != null; e = e.next)
			if(sameKey(e.key(), key))
				return e;
		return null;
	}

	/** Find the entry with the given key by scanning every bucket, or null if there is none. */
	private ACount<T> findInAnyBucket(final T key) {
		for(ACount<T> head : data) {
			final ACount<T> hit = findInChain(head, key);
			if(hit != null)
				return hit;
		}
		return null;
	}

	/** Index of the bucket the given key hashes to. */
	private int bucketOf(final T key) {
		return hash(key) % data.length;
	}

	/** Null safe equality of two keys, by reference first and equals second. */
	private static boolean sameKey(final Object a, final Object b) {
		return a == b || (a != null && a.equals(b));
	}

	/** True if the key is non-null and equal to the boxed double 0.0. */
	private static boolean isZeroKey(final Object key) {
		return key != null && key.equals(0.0);
	}

	/**
	 * Allocate a bucket or result array.
	 *
	 * @param size The length of the array
	 * @return A new array of the given length
	 */
	protected abstract ACount<T>[] create(int size);

	/**
	 * Hash the given key.
	 *
	 * NOTE(review): the result is used directly with the remainder operator as an array index, so it is assumed to
	 * be non-negative; confirm in the subclasses.
	 *
	 * @param key The key to hash
	 * @return The hash of the key
	 */
	protected abstract int hash(T key);

	/**
	 * Allocate a new entry.
	 *
	 * @param key The key of the entry
	 * @param id  The id of the entry
	 * @return A new entry
	 */
	protected abstract ACount<T> create(T key, int id);

	@Override
	public String toString() {
		final StringBuilder sb = new StringBuilder();
		sb.append(this.getClass().getSimpleName());
		for(ACount<T> head : data)
			if(head != null)
				sb.append(", ").append(head);
		return sb.toString();
	}

}
Loading

0 comments on commit eba1ada

Please sign in to comment.