From eba1adafef5121fdec730b075930145336e1b985 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 14 Aug 2023 01:41:31 +0200 Subject: [PATCH] Update --- .../runtime/compress/utils/ACountHashMap.java | 310 ++++++++++++++++++ .../compress/utils/DoubleCountHashMap.java | 304 +---------------- 2 files changed, 326 insertions(+), 288 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java new file mode 100644 index 00000000000..b36c5170ddd --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.sysds.runtime.compress.utils;
+
+import java.util.Arrays;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+public abstract class ACountHashMap<T> {
+	protected static final Log LOG = LogFactory.getLog(ACountHashMap.class.getName());
+	protected static final int RESIZE_FACTOR = 2;
+	protected static final float LOAD_FACTOR = 0.80f;
+	protected static final int shortCutSize = 10;
+
+	protected int size;
+	private ACount<T>[] data;
+
+	public ACountHashMap(ACount<T>[] data) {
+		this.data = data;
+		this.size = 0;
+	}
+
+	public ACountHashMap(ACount<T>[] data, int size) {
+		this.data = data;
+		this.size = size;
+	}
+
+	public int size() {
+		return size;
+	}
+
+	public final int increment(T key) {
+		return increment(key, 1);
+	}
+
+	public final int increment(final T key, final int count) {
+		return size < shortCutSize ? incrementBypassHash(key, count) : incrementNormal(key, count);
+	}
+
+	private final int incrementBypassHash(final T key, final int count) {
+		for(int i = 0; i < data.length; i++) {
+
+			ACount<T> l = data[i];
+			while(l != null) {
+				if(l.key().equals(key)) {
+					l.count += count;
+					return l.id;
+				}
+				else
+					l = l.next;
+			}
+		}
+		final int ix = hash(key) % data.length;
+		return addNewDCounts(ix, key);
+	}
+
+	private final int incrementNormal(final T key, final int count) {
+
+		final int ix = hash(key) % data.length;
+		ACount<T> l = data[ix];
+		while(l != null) {
+			if(l.key().equals(key)) {
+				l.count += count;
+				return l.id;
+			}
+			else
+				l = l.next;
+		}
+		return addNewDCounts(ix, key);
+	}
+
+	private int addNewDCounts(final int ix, final T key) {
+		ACount<T> ob = data[ix];
+		data[ix] = create(key, size);
+		data[ix].next = ob;
+		final int id = size++;
+		if(size >= LOAD_FACTOR * data.length && size > shortCutSize)
+			resize();
+		return id;
+	}
+
+	public int get(T key) {
+		return getC(key).count;
+	}
+
+	public int getId(T key) {
+		return getC(key).id;
+	}
+
+	public ACount<T> getC(T key) {
+		return size < shortCutSize ?
getCByPassHash(key) : getCNormal(key); + } + + private ACount getCByPassHash(T key) { + for(int i = 0; i < data.length; i++) { + ACount l = data[i]; + while(l != null) + if(l.key() == key) + return l; + else + l = l.next; + } + return null; + } + + private ACount getCNormal(T key) { + int ix = hash(key) % data.length; + ACount l = data[ix]; + while(!(l.key() == key)) + l = l.next; + return l; + } + + public int getOrDefault(T key, int def) { + return size < shortCutSize ? getOrDefaultShortCut(key, def) : getOrDefaultHash(key, def); + } + + private int getOrDefaultShortCut(T key, int def) { + for(int i = 0; i < data.length; i++) { + ACount l = data[i]; + while(l != null) + if(l.key() == key) + return l.count; + else + l = l.next; + } + return def; + } + + private int getOrDefaultHash(T key, int def) { + int ix = hash(key) % data.length; + ACount l = data[ix]; + while(l != null && !(l.key() == key)) + l = l.next; + if(l == null) + return def; + return l.count; + } + + public ACount[] extractValues() { + ACount[] ret = create(size); + int i = 0; + for(ACount e : data) { + while(e != null) { + // no cleanup aka potentially linked lists in cells. 
+				ret[i++] = e;
+				// ret[i++] = e.cleanCopy();
+				e = e.next;
+			}
+		}
+
+		return ret;
+	}
+
+	public void replaceWithUIDs() {
+		int i = 0;
+		for(ACount<T> e : data)
+			while(e != null) {
+				e.count = i++;
+				e = e.next;
+			}
+	}
+
+	public void replaceWithUIDsNoZero() {
+		int i = 0;
+		for(ACount<T> e : data) {
+			while(e != null) {
+				if(!e.key().equals(0.0))
+					e.count = i++;
+				e = e.next;
+			}
+		}
+	}
+
+	public int[] getUnorderedCountsAndReplaceWithUIDs() {
+		final int[] counts = new int[size];
+		int i = 0;
+		for(ACount<T> e : data)
+			while(e != null) {
+				counts[i] = e.count;
+				e.count = i++;
+				e = e.next;
+			}
+
+		return counts;
+	}
+
+	public int[] getUnorderedCountsAndReplaceWithUIDsWithout0() {
+		final int[] counts = new int[size];
+		int i = 0;
+		for(ACount<T> e : data) {
+			while(e != null) {
+				if(!e.key().equals(0.0)) {
+					counts[i] = e.count;
+					e.count = i++;
+				}
+				e = e.next;
+			}
+		}
+
+		return counts;
+	}
+
+	public T getMostFrequent() {
+		T f = null;
+		int fq = 0;
+		for(ACount<T> e : data) {
+			while(e != null) {
+				if(e.count > fq) {
+					fq = e.count;
+					f = e.key();
+				}
+				e = e.next;
+			}
+		}
+		return f;
+	}
+
+	private void resize() {
+		// check for integer overflow on resize
+		if(data.length > Integer.MAX_VALUE / RESIZE_FACTOR)
+			return;
+		resize(Math.max(data.length * RESIZE_FACTOR, shortCutSize));
+	}
+
+	private void resize(int size) {
+		// resize data array and copy existing contents
+		final ACount<T>[] olddata = data;
+		data = create(size);
+		this.size = 0;
+		// rehash all entries
+		for(ACount<T> e : olddata)
+			appendValue(e);
+	}
+
+	private void appendValue(ACount<T> ent) {
+		if(ent != null) {
+			// take the tail recursively first
+			appendValue(ent.next); // append tail first
+			ent.next = null; // set this tail to null.
+			final int ix = ent.hashIndex() % data.length;
+			// sorted insert based on count.
+			ACount<T> l = data[ix];
+			if(data[ix] == null) {
+				data[ix] = ent;
+			}
+			else {
+				ACount<T> p = l.next;
+				while(p != null && p.count > ent.count) {
+					l = p;
+					p = p.next;
+				}
+				l.next = ent;
+				ent.next = p;
+			}
+			size++;
+		}
+	}
+
+	public void sortBuckets() {
+		if(size > 10)
+			for(int i = 0; i < data.length; i++)
+				if(data[i] != null)
+					data[i] = data[i].sort();
+	}
+
+	public void reset(int size) {
+		int p2 = Util.getPow2(size);
+		if(data.length > 2 * p2)
+			data = create(p2);
+		else
+			Arrays.fill(data, null);
+		this.size = 0;
+	}
+
+	protected abstract ACount<T>[] create(int size);
+
+	protected abstract int hash(T key);
+
+	protected abstract ACount<T> create(T key, int id);
+
+	@Override
+	public String toString() {
+		StringBuilder sb = new StringBuilder();
+		sb.append(this.getClass().getSimpleName());
+		for(int i = 0; i < data.length; i++)
+			if(data[i] != null)
+				sb.append(", " + data[i]);
+		return sb.toString();
+	}
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java
index 8d81211d466..0ad0b7c4c43 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java
@@ -19,315 +19,43 @@
 
 package org.apache.sysds.runtime.compress.utils;
 
-import java.util.Arrays;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.compress.utils.ACount.DCounts;
 
-public class DoubleCountHashMap {
-
-	protected static final Log LOG = LogFactory.getLog(DoubleCountHashMap.class.getName());
-	protected static final int RESIZE_FACTOR = 2;
-	protected static final float LOAD_FACTOR = 0.80f;
-	protected static final int shortCutSize = 10;
-
-	protected int _size = -1;
-	private DCounts[] _data = null;
+public class DoubleCountHashMap extends ACountHashMap<Double> {
 
 	public DoubleCountHashMap() {
-		_data = new
DCounts[1]; - _size = 0; + super(new DCounts[1], 0); } public DoubleCountHashMap(int init_capacity) { - if(init_capacity < shortCutSize) - _data = new DCounts[1]; - else - _data = new DCounts[(Util.getPow2(init_capacity) / 2) + 5]; - _size = 0; - } - - public int size() { - return _size; - } - - public final int increment(final double key) { - return increment(key, 1); - } - - public final int increment(final double key, final int count) { - return _size < shortCutSize ? incrementBypassHash(key, count) : incrementNormal(key, count); - } - - private final int incrementBypassHash(final double key, final int count) { - for(int i = 0; i < _data.length; i++) { - - DCounts l = _data[i]; - while(l != null) { - if(l.key() == key) { - l.count += count; - return l.id; - } - else - l = (DCounts) l.next; - } - } - final int ix = DCounts.hashIndex(key) % _data.length; - return addNewDCounts(ix, key); + super(init_capacity < shortCutSize ? // + new DCounts[1] : // + new DCounts[(Util.getPow2(init_capacity) / 2) + 5], 0); } - private final int incrementNormal(final double key, final int count) { - - final int ix = DCounts.hashIndex(key) % _data.length; - DCounts l = _data[ix]; - while(l != null) { - if(l.key() == key) { - l.count += count; - return l.id; - } - else - l = (DCounts) l.next; - } - return addNewDCounts(ix, key); + protected ACount[] create(int size) { + return new DCounts[size]; } - private int addNewDCounts(final int ix, final double key) { - ACount ob = _data[ix]; - _data[ix] = new DCounts(key, _size); - _data[ix].next = ob; - final int id = _size++; - if(_size >= LOAD_FACTOR * _data.length && _size > shortCutSize) - resize(); - return id; + protected int hash(Double key) { + return DCounts.hashIndex(key); } - /** - * Get the value on a key, if the key is not inside a NullPointerException is thrown. 
-	 *
-	 * @param key the key to lookup
-	 * @return count on key
-	 */
-	public int get(double key) {
-		return getC(key).count;
+	protected ACount<Double> create(Double key, int id) {
+		return new DCounts(key, id);
 	}
 
-	/**
-	 * Get the ID behind the key, if it does not exist a null pointer is thrown
-	 *
-	 * @param key The key array
-	 * @return The Id or null pointer exception
-	 */
-	public int getId(double key) {
-		return getC(key).id;
-	}
-
-	public ACount getC(double key) {
-		return _size < shortCutSize ? getCByPassHash(key) : getCNormal(key);
-	}
-
-	private ACount getCByPassHash(double key) {
-		try {
-			for(int i = 0; i < _data.length; i++) {
-				ACount l = _data[i];
-				while(l != null)
-					if(l.key() == key)
-						return l;
-					else
-						l = l.next;
-			}
-			return null;
-		}
-		catch(Exception e) {
-			if(Double.isNaN(key))
-				return getC(0.0);
-			throw new NullPointerException("Failed to getKey : " + key + " in " + this);
-		}
-	}
-
-	private ACount getCNormal(double key) {
+	@Override
+	public ACount<Double> getC(Double key) {
 		try {
-			int ix = DCounts.hashIndex(key) % _data.length;
-			ACount l = _data[ix];
-			while(!(l.key() == key))
-				l = l.next;
-			return l;
+			return super.getC(key);
 		}
 		catch(Exception e) {
-			if(Double.isNaN(key))
-				return getC(0.0);
+			if(key != null && Double.isNaN(key))
+				return getC(Double.valueOf(0.0));
 			throw new NullPointerException("Failed to getKey : " + key + " in " + this);
-		}
-	}
-
-	public int getOrDefault(double key, int def) {
-		int ix = DCounts.hashIndex(key) % _data.length;
-		ACount l = _data[ix];
-		while(l != null && !(l.key() == key))
-			l = l.next;
-		if(l == null)
-			return def;
-		return l.count;
-	}
-
-	/**
-	 * We extract the values without allocation maybe we need to do clean copy?
-	 *
-	 * @return
-	 */
-	public DCounts[] extractValues() {
-		DCounts[] ret = new DCounts[_size];
-		int i = 0;
-		for(ACount e : _data) {
-			while(e != null) {
-				// no cleanup aka potentially linked lists in cells.
- ret[i++] = (DCounts) e; - // ret[i++] = e.cleanCopy(); - e = e.next; - } } - - return ret; - } - - public void replaceWithUIDs() { - int i = 0; - for(ACount e : _data) - while(e != null) { - e.count = i++; - e = e.next; - } - } - - public void replaceWithUIDsNoZero() { - int i = 0; - for(ACount e : _data) { - while(e != null) { - if(e.key() != 0) - e.count = i++; - e = e.next; - } - } - } - - public int[] getUnorderedCountsAndReplaceWithUIDs() { - final int[] counts = new int[_size]; - int i = 0; - for(ACount e : _data) - while(e != null) { - counts[i] = e.count; - e.count = i++; - e = e.next; - } - - return counts; - } - - public int[] getUnorderedCountsAndReplaceWithUIDsWithout0() { - final int[] counts = new int[_size]; - int i = 0; - for(ACount e : _data) { - while(e != null) { - if(e.key() != 0) { - counts[i] = e.count; - e.count = i++; - } - e = e.next; - } - } - - return counts; - } - - public double getMostFrequent() { - double f = 0; - int fq = 0; - for(ACount e : _data) { - while(e != null) { - if(e.count > fq) { - fq = e.count; - f = e.key(); - } - e = e.next; - } - } - return f; } - private void resize() { - // check for integer overflow on resize - if(_data.length > Integer.MAX_VALUE / RESIZE_FACTOR) - return; - resize(Math.max(_data.length * RESIZE_FACTOR, shortCutSize)); - } - - private void resize(int size) { - // resize data array and copy existing contents - final ACount[] olddata = _data; - _data = new DCounts[size]; - _size = 0; - // rehash all entries - for(ACount e : olddata) - appendValue(e); - } - - private void appendValue(ACount ent) { - if(ent != null) { - // take the tail recursively first - appendValue(ent.next); // append tail first - ent.next = null; // set this tail to null. - final int ix = ent.hashIndex() % _data.length; - // sorted insert based on count. 
- ACount l = _data[ix]; - if(_data[ix] == null) { - _data[ix] = (DCounts) ent; - } - else { - ACount p = l.next; - while(p != null && p.count > ent.count) { - l = p; - p = p.next; - } - l.next = ent; - ent.next = p; - } - _size++; - } - } - - public void sortBuckets() { - if(_size > 10) - for(int i = 0; i < _data.length; i++) - if(_data[i] != null) - _data[i] = (DCounts) _data[i].sort(); - } - - public double[] getDictionary() { - final double[] ret = new double[_size]; - for(ACount e : _data) - while(e != null) { - ret[e.id] = e.key(); - e = e.next; - } - return ret; - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append(this.getClass().getSimpleName()); - for(int i = 0; i < _data.length; i++) - if(_data[i] != null) - sb.append(", " + _data[i]); - return sb.toString(); - } - - public void reset(int size) { - int p2 = Util.getPow2(size); - if(_data.length > 2 * p2) - _data = new DCounts[p2]; - else - Arrays.fill(_data, null); - _size = 0; - } }