Skip to content

Commit

Permalink
Implementation of CRAM 3.1 codecs.
Browse files Browse the repository at this point in the history
  • Loading branch information
yash-puligundla authored and cmnbroad committed Sep 4, 2024
1 parent 46f6f0f commit 12a3e4c
Show file tree
Hide file tree
Showing 55 changed files with 5,907 additions and 1,086 deletions.
177 changes: 177 additions & 0 deletions src/main/java/htsjdk/samtools/cram/compression/CompressionUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package htsjdk.samtools.cram.compression;

import htsjdk.samtools.cram.CRAMException;
import htsjdk.samtools.cram.compression.rans.Constants;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class CompressionUtils {
public static void writeUint7(final int i, final ByteBuffer cp) {
int s = 0;
int X = i;
do {
s += 7;
X >>= 7;
} while (X > 0);
do {
s -= 7;
//writeByte
final int s_ = (s > 0) ? 1 : 0;
cp.put((byte) (((i >> s) & 0x7f) + (s_ << 7)));
} while (s > 0);
}

public static int readUint7(final ByteBuffer cp) {
int i = 0;
int c;
do {
//read byte
c = cp.get();
i = (i << 7) | (c & 0x7f);
} while ((c & 0x80) != 0);
return i;
}

public static ByteBuffer encodePack(
final ByteBuffer inBuffer,
final ByteBuffer outBuffer,
final int[] frequencyTable,
final int[] packMappingTable,
final int numSymbols){
final int inSize = inBuffer.remaining();
final ByteBuffer encodedBuffer;
if (numSymbols <= 1) {
encodedBuffer = CompressionUtils.allocateByteBuffer(0);
} else if (numSymbols <= 2) {

// 1 bit per value
final int encodedBufferSize = (int) Math.ceil((double) inSize/8);
encodedBuffer = CompressionUtils.allocateByteBuffer(encodedBufferSize);
int j = -1;
for (int i = 0; i < inSize; i ++) {
if (i % 8 == 0) {
encodedBuffer.put(++j, (byte) 0);
}
encodedBuffer.put(j, (byte) (encodedBuffer.get(j) + (packMappingTable[inBuffer.get(i) & 0xFF] << (i % 8))));
}
} else if (numSymbols <= 4) {

// 2 bits per value
final int encodedBufferSize = (int) Math.ceil((double) inSize/4);
encodedBuffer = CompressionUtils.allocateByteBuffer(encodedBufferSize);
int j = -1;
for (int i = 0; i < inSize; i ++) {
if (i % 4 == 0) {
encodedBuffer.put(++j, (byte) 0);
}
encodedBuffer.put(j, (byte) (encodedBuffer.get(j) + (packMappingTable[inBuffer.get(i) & 0xFF] << ((i % 4) * 2))));
}
} else {

// 4 bits per value
final int encodedBufferSize = (int) Math.ceil((double)inSize/2);
encodedBuffer = CompressionUtils.allocateByteBuffer(encodedBufferSize);
int j = -1;
for (int i = 0; i < inSize; i ++) {
if (i % 2 == 0) {
encodedBuffer.put(++j, (byte) 0);
}
encodedBuffer.put(j, (byte) (encodedBuffer.get(j) + (packMappingTable[inBuffer.get(i) & 0xFF] << ((i % 2) * 4))));
}
}

// write numSymbols
outBuffer.put((byte) numSymbols);

// write mapping table "packMappingTable" that converts mapped value to original symbol
for(int i = 0; i < Constants.NUMBER_OF_SYMBOLS; i ++) {
if (frequencyTable[i] > 0) {
outBuffer.put((byte) i);
}
}

// write the length of data
CompressionUtils.writeUint7(encodedBuffer.limit(), outBuffer);
return encodedBuffer; // Here position = 0 since we have always accessed the data buffer using index
}

public static ByteBuffer decodePack(
final ByteBuffer inBuffer,
final byte[] packMappingTable,
final int numSymbols,
final int uncompressedPackOutputLength) {
final ByteBuffer outBufferPack = CompressionUtils.allocateByteBuffer(uncompressedPackOutputLength);
int j = 0;
if (numSymbols <= 1) {
for (int i=0; i < uncompressedPackOutputLength; i++){
outBufferPack.put(i, packMappingTable[0]);
}
}

// 1 bit per value
else if (numSymbols <= 2) {
int v = 0;
for (int i=0; i < uncompressedPackOutputLength; i++){
if (i % 8 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, packMappingTable[v & 1]);
v >>=1;
}
}

// 2 bits per value
else if (numSymbols <= 4){
int v = 0;
for(int i=0; i < uncompressedPackOutputLength; i++){
if (i % 4 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, packMappingTable[v & 3]);
v >>=2;
}
}

// 4 bits per value
else if (numSymbols <= 16){
int v = 0;
for(int i=0; i < uncompressedPackOutputLength; i++){
if (i % 2 == 0){
v = inBuffer.get(j++);
}
outBufferPack.put(i, packMappingTable[v & 15]);
v >>=4;
}
}
return outBufferPack;
}

public static ByteBuffer allocateOutputBuffer(final int inSize) {
// This calculation is identical to the one in samtools rANS_static.c
// Presumably the frequency table (always big enough for order 1) = 257*257,
// then * 3 for each entry (byte->symbol, 2 bytes -> scaled frequency),
// + 9 for the header (order byte, and 2 int lengths for compressed/uncompressed lengths).
final int compressedSize = (int) (inSize + 257 * 257 * 3 + 9);
final ByteBuffer outputBuffer = allocateByteBuffer(compressedSize);
if (outputBuffer.remaining() < compressedSize) {
throw new CRAMException("Failed to allocate sufficient buffer size for RANS coder.");
}
return outputBuffer;
}

// returns a new LITTLE_ENDIAN ByteBuffer of size = bufferSize
public static ByteBuffer allocateByteBuffer(final int bufferSize){
return ByteBuffer.allocate(bufferSize).order(ByteOrder.LITTLE_ENDIAN);
}

// returns a LITTLE_ENDIAN ByteBuffer that is created by wrapping a byte[]
public static ByteBuffer wrap(final byte[] inputBytes){
return ByteBuffer.wrap(inputBytes).order(ByteOrder.LITTLE_ENDIAN);
}

// returns a LITTLE_ENDIAN ByteBuffer that is created by inputBuffer.slice()
public static ByteBuffer slice(final ByteBuffer inputBuffer){
return inputBuffer.slice().order(ByteOrder.LITTLE_ENDIAN);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package htsjdk.samtools.cram.compression;

import htsjdk.samtools.cram.compression.rans.RANS;
import htsjdk.samtools.cram.compression.range.RangeDecode;
import htsjdk.samtools.cram.compression.range.RangeEncode;
import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Decode;
import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Encode;
import htsjdk.samtools.cram.structure.block.BlockCompressionMethod;
import htsjdk.utils.ValidationUtils;

Expand Down Expand Up @@ -71,8 +74,13 @@ public static ExternalCompressor getCompressorForMethod(

case RANS:
return compressorSpecificArg == NO_COMPRESSION_ARG ?
new RANSExternalCompressor(new RANS()) :
new RANSExternalCompressor(compressorSpecificArg, new RANS());
new RANSExternalCompressor(new RANS4x8Encode(), new RANS4x8Decode()) :
new RANSExternalCompressor(compressorSpecificArg, new RANS4x8Encode(), new RANS4x8Decode());

case RANGE:
return compressorSpecificArg == NO_COMPRESSION_ARG ?
new RangeExternalCompressor(new RangeEncode(), new RangeDecode()) :
new RangeExternalCompressor(compressorSpecificArg, new RangeEncode(), new RangeDecode());

case BZIP2:
ValidationUtils.validateArg(
Expand All @@ -85,5 +93,4 @@ public static ExternalCompressor getCompressorForMethod(
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,48 +24,60 @@
*/
package htsjdk.samtools.cram.compression;

import htsjdk.samtools.cram.compression.rans.RANS;
import htsjdk.samtools.cram.compression.rans.RANSParams;
import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Decode;
import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Encode;
import htsjdk.samtools.cram.compression.rans.rans4x8.RANS4x8Params;
import htsjdk.samtools.cram.structure.block.BlockCompressionMethod;

import java.nio.ByteBuffer;
import java.util.Objects;

public final class RANSExternalCompressor extends ExternalCompressor {
private final RANS.ORDER order;
private final RANS rans;
private final RANSParams.ORDER order;
private final RANS4x8Encode ransEncode;
private final RANS4x8Decode ransDecode;

/**
* We use a shared RANS instance for all compressors.
* @param rans
*/
public RANSExternalCompressor(final RANS rans) {
this(RANS.ORDER.ZERO, rans);
public RANSExternalCompressor(
final RANS4x8Encode ransEncode,
final RANS4x8Decode ransDecode) {
this(RANSParams.ORDER.ZERO, ransEncode, ransDecode);
}

public RANSExternalCompressor(final int order, final RANS rans) {
this(RANS.ORDER.fromInt(order), rans);
public RANSExternalCompressor(
final int order,
final RANS4x8Encode ransEncode,
final RANS4x8Decode ransDecode) {
this(RANSParams.ORDER.fromInt(order), ransEncode, ransDecode);
}

public RANSExternalCompressor(final RANS.ORDER order, final RANS rans) {
public RANSExternalCompressor(
final RANSParams.ORDER order,
final RANS4x8Encode ransEncode,
final RANS4x8Decode ransDecode) {
super(BlockCompressionMethod.RANS);
this.rans = rans;
this.ransEncode = ransEncode;
this.ransDecode = ransDecode;
this.order = order;
}

@Override
public byte[] compress(final byte[] data) {
final ByteBuffer buffer = rans.compress(ByteBuffer.wrap(data), order);
final RANS4x8Params params = new RANS4x8Params(order);
final ByteBuffer buffer = ransEncode.compress(CompressionUtils.wrap(data), params);
return toByteArray(buffer);
}

@Override
public byte[] uncompress(byte[] data) {
final ByteBuffer buf = rans.uncompress(ByteBuffer.wrap(data));
final ByteBuffer buf = ransDecode.uncompress(CompressionUtils.wrap(data));
return toByteArray(buf);
}

public RANS.ORDER getOrder() { return order; }

@Override
public String toString() {
return String.format("%s(%s)", this.getMethod(), order);
Expand Down Expand Up @@ -96,4 +108,4 @@ private byte[] toByteArray(final ByteBuffer buffer) {
return bytes;
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package htsjdk.samtools.cram.compression;

import htsjdk.samtools.cram.compression.range.RangeDecode;
import htsjdk.samtools.cram.compression.range.RangeEncode;
import htsjdk.samtools.cram.compression.range.RangeParams;
import htsjdk.samtools.cram.structure.block.BlockCompressionMethod;

import java.nio.ByteBuffer;

public class RangeExternalCompressor extends ExternalCompressor{

private final int formatFlags;
private final RangeEncode rangeEncode;
private final RangeDecode rangeDecode;

public RangeExternalCompressor(
final RangeEncode rangeEncode,
final RangeDecode rangeDecode) {
this(0, rangeEncode, rangeDecode);
}

public RangeExternalCompressor(
final int formatFlags,
final RangeEncode rangeEncode,
final RangeDecode rangeDecode) {
super(BlockCompressionMethod.RANGE);
this.rangeEncode = rangeEncode;
this.rangeDecode = rangeDecode;
this.formatFlags = formatFlags;
}

@Override
public byte[] compress(byte[] data) {
final RangeParams params = new RangeParams(formatFlags);
final ByteBuffer buffer = rangeEncode.compress(CompressionUtils.wrap(data), params);
return toByteArray(buffer);
}

@Override
public byte[] uncompress(byte[] data) {
final ByteBuffer buf = rangeDecode.uncompress(CompressionUtils.wrap(data));
return toByteArray(buf);
}

@Override
public String toString() {
return String.format("%s(%s)", this.getMethod(),formatFlags);
}

private byte[] toByteArray(final ByteBuffer buffer) {
if (buffer.hasArray() && buffer.arrayOffset() == 0 && buffer.array().length == buffer.limit()) {
return buffer.array();
}

final byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);
return bytes;
}


}
Loading

0 comments on commit 12a3e4c

Please sign in to comment.