Skip to content

Commit

Permalink
Assembly: Added keepRules which runs on the final jar
Browse files Browse the repository at this point in the history
This allows users of this lib/plugin to keep only files from libraries
they use in their project without specifying them in zap
  • Loading branch information
shanielh committed Dec 11, 2022
1 parent b074c40 commit 216b65a
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 19 deletions.
2 changes: 2 additions & 0 deletions src/main/contraband/AssemblyOption.contra
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ type AssemblyOption {

shadeRules: sbtassembly.Assembly.SeqShadeRules! = raw"sbtassembly.Assembly.defaultShadeRules" @since("0.15.0")

keepRules: sbtassembly.Assembly.SeqString! = raw"sbtassembly.Assembly.defaultKeepRules" @since("2.0.1")

scalaVersion: String! = "" @since("0.15.0")

level: sbt.Level.Value! = raw"sbt.Level.Info" @since("0.15.0")
Expand Down
67 changes: 55 additions & 12 deletions src/main/scala/sbtassembly/Assembly.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,29 @@ package sbtassembly
import com.eed3si9n.jarjarabrams._
import sbt.Def.Initialize
import sbt.Keys._
import sbt.Package.{ manifestFormat, JarManifest, MainClass, ManifestAttributes }
import sbt.Package.{JarManifest, MainClass, ManifestAttributes, manifestFormat}
import sbt.internal.util.HListFormats._
import sbt.internal.util.HNil
import sbt.internal.util.Types.:+:
import sbt.io.{ DirectoryFilter => _, IO => _, Path => _, Using }
import sbt.io.{Using, DirectoryFilter => _, IO => _, Path => _}
import sbt.util.FileInfo.lastModified
import sbt.util.Tracked.{ inputChanged, lastOutput }
import sbt.util.{ FilesInfo, Level, ModifiedFileInfo }
import sbt.{ File, Logger, _ }
import sbt.util.Tracked.{inputChanged, lastOutput}
import sbt.util.{FilesInfo, Level, ModifiedFileInfo}
import sbt.{File, Logger, _}
import sbt.Tags.Tag
import CacheImplicits._
import sbtassembly.AssemblyPlugin.autoImport.{ Assembly => _, _ }
import com.eed3si9n.jarjar.util.EntryStruct
import com.eed3si9n.jarjar.{JJProcessor, Keep}
import sbtassembly.AssemblyPlugin.autoImport.{Assembly => _, _}
import sbtassembly.PluginCompat.ClasspathUtilities

import java.io._
import java.net.URI
import java.nio.file.attribute.{ BasicFileAttributeView, FileTime, PosixFilePermission }
import java.nio.file.{ Path, _ }
import java.nio.file.attribute.{BasicFileAttributeView, FileTime, PosixFilePermission}
import java.nio.file.{Path, _}
import java.security.MessageDigest
import java.time.Instant
import java.util.jar.{ Attributes => JAttributes, JarFile, Manifest => JManifest }
import java.util.jar.{JarFile, Attributes => JAttributes, Manifest => JManifest}
import scala.annotation.tailrec
import scala.collection.GenSeq
import scala.collection.JavaConverters._
Expand All @@ -36,7 +38,8 @@ object Assembly {
type SeqShadeRules = Seq[com.eed3si9n.jarjarabrams.ShadeRule]
type LazyInputStream = () => InputStream

val defaultShadeRules: Seq[com.eed3si9n.jarjarabrams.ShadeRule] = Nil
val defaultShadeRules: SeqShadeRules = Nil
val defaultKeepRules: SeqString = Nil
val newLine: String = "\n"
val indent: String = " " * 2
val newLineIndented: String = newLine + indent
Expand Down Expand Up @@ -336,11 +339,15 @@ object Assembly {
timed(Level.Debug, "Finding remaining conflicts that were not merged") {
reportConflictsMissedByTheMerge(mergedEntries, log)
}
val finalEntries = timed(Level.Debug, "Applying keep rules") {
val entries = mergedEntries.flatMap(_.entries)
keepShader(ao.keepRules, log, entries)
}
val jarEntriesToWrite = timed(Level.Debug, "Sort/Parallelize merged entries") {
if (ao.repeatableBuild) // we need the jars in a specific order to have a consistent hash
mergedEntries.flatMap(_.entries).seq.sortBy(_.target)
finalEntries.seq.sortBy(_.target)
else // we actually gain performance when creating the jar in parallel, but we won't have a consistent hash
mergedEntries.flatMap(_.entries).par
finalEntries.par
}
val localTime = timestamp
.map(t => t - java.util.TimeZone.getDefault.getOffset(t))
Expand Down Expand Up @@ -565,6 +572,42 @@ object Assembly {
}
}

private[sbtassembly] def keepShader(keepRules: SeqString, log: Logger, entries: Seq[JarEntry]): Seq[JarEntry] = {
val jjRules = keepRules.map(pattern => {
val jRule = new Keep()
jRule.setPattern(pattern)
jRule
})

val proc = new JJProcessor(jjRules, true, true, null)
log.info(s"kp.isEnabled: ${proc.keepList.map(_.getPattern)}")
log.info(s"keepRules: ${keepRules}")

val entryStructs = entries.map({ entry =>
val stream = entry.stream()
val entryStruct = new EntryStruct()
val mapping = entry.target
entryStruct.name = if (mapping.contains('\\')) mapping.replace('\\', '/') else mapping
entryStruct.data = Streamable.bytes(stream)
entryStruct.time = -1
entryStruct.skipTransform = false
stream.close()
entryStruct
})

val countTrue = entryStructs.count(proc.process)
log.info(s"kp.isEnabled: ${proc.kp.asInstanceOf[Object].getClass.getField().isEnabled}")
log.info(s"process.true: ${countTrue}, false: ${entryStructs.length - countTrue}")
val itemsToExclude = proc.getExcludes
log.info(s"items to exclude: ${itemsToExclude.size}")
entryStructs.filterNot(entry => itemsToExclude.contains(entry.name))
.map(entryStruct => {
val mapping = entryStruct.name
val name = if (mapping.contains('/')) mapping.replace('/', '\\') else mapping
JarEntry(name, () => new ByteArrayInputStream(entryStruct.data))
})
}

private[sbtassembly] def createManifest(po: Seq[PackageOption], log: Logger): (JManifest, Option[Long]) = {
import scala.language.reflectiveCalls
val manifest = new JManifest
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/sbtassembly/AssemblyKeys.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ trait AssemblyKeys {
lazy val assemblyExcludedJars = taskKey[Classpath]("list of excluded jars")
lazy val assemblyMergeStrategy = settingKey[String => MergeStrategy]("mapping from archive member path to merge strategy")
lazy val assemblyShadeRules = settingKey[Seq[jarjarabrams.ShadeRule]]("shading rules backed by jarjar")
lazy val assemblyKeepRules = settingKey[Seq[String]]("Keep rules backed by jarjar to run on the final assembled JAR")
lazy val assemblyAppendContentHash = settingKey[Boolean]("Appends SHA-1 fingerprint to the assembly file name")
lazy val assemblyMaxHashLength = settingKey[Int]("Length of SHA-1 fingerprint used for the assembly file name")
lazy val assemblyCacheOutput = settingKey[Boolean]("Enables (true) or disables (false) cacheing the output if the content has not changed")
Expand Down
21 changes: 14 additions & 7 deletions src/main/scala/sbtassembly/AssemblyOption.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,26 @@ final class AssemblyOption private (
val prependShellScript: Option[sbtassembly.Assembly.SeqString],
val maxHashLength: Option[Int],
val shadeRules: sbtassembly.Assembly.SeqShadeRules,
val keepRules: sbtassembly.Assembly.SeqString,
val scalaVersion: String,
val level: sbt.Level.Value) extends Serializable {

private def this() = this(true, true, true, Nil, true, sbtassembly.MergeStrategy.defaultMergeStrategy, true, false, None, None, sbtassembly.Assembly.defaultShadeRules, "", sbt.Level.Info)
private def this(includeBin: Boolean, includeScala: Boolean, includeDependency: Boolean, excludedJars: sbt.Keys.Classpath, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy, cacheOutput: Boolean, appendContentHash: Boolean, prependShellScript: Option[sbtassembly.Assembly.SeqString], maxHashLength: Option[Int], shadeRules: sbtassembly.Assembly.SeqShadeRules, scalaVersion: String, level: sbt.Level.Value) = this(includeBin, includeScala, includeDependency, excludedJars, true, mergeStrategy, cacheOutput, appendContentHash, prependShellScript, maxHashLength, shadeRules, scalaVersion, level)
private def this() = this(true, true, true, Nil, true, sbtassembly.MergeStrategy.defaultMergeStrategy, true, false, None, None, sbtassembly.Assembly.defaultShadeRules, sbtassembly.Assembly.defaultKeepRules, "", sbt.Level.Info)
private def this(includeBin: Boolean, includeScala: Boolean, includeDependency: Boolean, excludedJars: sbt.Keys.Classpath, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy, cacheOutput: Boolean, appendContentHash: Boolean, prependShellScript: Option[sbtassembly.Assembly.SeqString], maxHashLength: Option[Int], shadeRules: sbtassembly.Assembly.SeqShadeRules, scalaVersion: String, level: sbt.Level.Value) = this(includeBin, includeScala, includeDependency, excludedJars, true, mergeStrategy, cacheOutput, appendContentHash, prependShellScript, maxHashLength, shadeRules, sbtassembly.Assembly.defaultKeepRules, scalaVersion, level)
private def this(includeBin: Boolean, includeScala: Boolean, includeDependency: Boolean, excludedJars: sbt.Keys.Classpath, repeatableBuild: Boolean, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy, cacheOutput: Boolean, appendContentHash: Boolean, prependShellScript: Option[sbtassembly.Assembly.SeqString], maxHashLength: Option[Int], shadeRules: sbtassembly.Assembly.SeqShadeRules, scalaVersion: String, level: sbt.Level.Value) = this(includeBin, includeScala, includeDependency, excludedJars, repeatableBuild, mergeStrategy, cacheOutput, appendContentHash, prependShellScript, maxHashLength, shadeRules, sbtassembly.Assembly.defaultKeepRules, scalaVersion, level)

override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match {
case x: AssemblyOption => (this.includeBin == x.includeBin) && (this.includeScala == x.includeScala) && (this.includeDependency == x.includeDependency) && (this.excludedJars == x.excludedJars) && (this.repeatableBuild == x.repeatableBuild) && (this.mergeStrategy == x.mergeStrategy) && (this.cacheOutput == x.cacheOutput) && (this.appendContentHash == x.appendContentHash) && (this.prependShellScript == x.prependShellScript) && (this.maxHashLength == x.maxHashLength) && (this.shadeRules == x.shadeRules) && (this.scalaVersion == x.scalaVersion) && (this.level == x.level)
case x: AssemblyOption => (this.includeBin == x.includeBin) && (this.includeScala == x.includeScala) && (this.includeDependency == x.includeDependency) && (this.excludedJars == x.excludedJars) && (this.repeatableBuild == x.repeatableBuild) && (this.mergeStrategy == x.mergeStrategy) && (this.cacheOutput == x.cacheOutput) && (this.appendContentHash == x.appendContentHash) && (this.prependShellScript == x.prependShellScript) && (this.maxHashLength == x.maxHashLength) && (this.shadeRules == x.shadeRules) && (this.keepRules == x.keepRules) && (this.scalaVersion == x.scalaVersion) && (this.level == x.level)
case _ => false
})
override def hashCode: Int = {
37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbtassembly.AssemblyOption".##) + includeBin.##) + includeScala.##) + includeDependency.##) + excludedJars.##) + repeatableBuild.##) + mergeStrategy.##) + cacheOutput.##) + appendContentHash.##) + prependShellScript.##) + maxHashLength.##) + shadeRules.##) + scalaVersion.##) + level.##)
37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbtassembly.AssemblyOption".##) + includeBin.##) + includeScala.##) + includeDependency.##) + excludedJars.##) + repeatableBuild.##) + mergeStrategy.##) + cacheOutput.##) + appendContentHash.##) + prependShellScript.##) + maxHashLength.##) + shadeRules.##) + keepRules.##) + scalaVersion.##) + level.##)
}
override def toString: String = {
"AssemblyOption(" + includeBin + ", " + includeScala + ", " + includeDependency + ", " + excludedJars + ", " + repeatableBuild + ", " + mergeStrategy + ", " + cacheOutput + ", " + appendContentHash + ", " + prependShellScript + ", " + maxHashLength + ", " + shadeRules + ", " + scalaVersion + ", " + level + ")"
"AssemblyOption(" + includeBin + ", " + includeScala + ", " + includeDependency + ", " + excludedJars + ", " + repeatableBuild + ", " + mergeStrategy + ", " + cacheOutput + ", " + appendContentHash + ", " + prependShellScript + ", " + maxHashLength + ", " + shadeRules + ", " + keepRules + ", " + scalaVersion + ", " + level + ")"
}
private[this] def copy(includeBin: Boolean = includeBin, includeScala: Boolean = includeScala, includeDependency: Boolean = includeDependency, excludedJars: sbt.Keys.Classpath = excludedJars, repeatableBuild: Boolean = repeatableBuild, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy = mergeStrategy, cacheOutput: Boolean = cacheOutput, appendContentHash: Boolean = appendContentHash, prependShellScript: Option[sbtassembly.Assembly.SeqString] = prependShellScript, maxHashLength: Option[Int] = maxHashLength, shadeRules: sbtassembly.Assembly.SeqShadeRules = shadeRules, scalaVersion: String = scalaVersion, level: sbt.Level.Value = level): AssemblyOption = {
new AssemblyOption(includeBin, includeScala, includeDependency, excludedJars, repeatableBuild, mergeStrategy, cacheOutput, appendContentHash, prependShellScript, maxHashLength, shadeRules, scalaVersion, level)
private[this] def copy(includeBin: Boolean = includeBin, includeScala: Boolean = includeScala, includeDependency: Boolean = includeDependency, excludedJars: sbt.Keys.Classpath = excludedJars, repeatableBuild: Boolean = repeatableBuild, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy = mergeStrategy, cacheOutput: Boolean = cacheOutput, appendContentHash: Boolean = appendContentHash, prependShellScript: Option[sbtassembly.Assembly.SeqString] = prependShellScript, maxHashLength: Option[Int] = maxHashLength, shadeRules: sbtassembly.Assembly.SeqShadeRules = shadeRules, keepRules: sbtassembly.Assembly.SeqString = keepRules, scalaVersion: String = scalaVersion, level: sbt.Level.Value = level): AssemblyOption = {
new AssemblyOption(includeBin, includeScala, includeDependency, excludedJars, repeatableBuild, mergeStrategy, cacheOutput, appendContentHash, prependShellScript, maxHashLength, shadeRules, keepRules, scalaVersion, level)
}
def withIncludeBin(includeBin: Boolean): AssemblyOption = {
copy(includeBin = includeBin)
Expand Down Expand Up @@ -78,6 +80,9 @@ final class AssemblyOption private (
def withShadeRules(shadeRules: sbtassembly.Assembly.SeqShadeRules): AssemblyOption = {
copy(shadeRules = shadeRules)
}
def withKeepRules(keepRules: sbtassembly.Assembly.SeqString): AssemblyOption = {
copy(keepRules = keepRules)
}
def withScalaVersion(scalaVersion: String): AssemblyOption = {
copy(scalaVersion = scalaVersion)
}
Expand All @@ -92,4 +97,6 @@ object AssemblyOption {
def apply(includeBin: Boolean, includeScala: Boolean, includeDependency: Boolean, excludedJars: sbt.Keys.Classpath, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy, cacheOutput: Boolean, appendContentHash: Boolean, prependShellScript: sbtassembly.Assembly.SeqString, maxHashLength: Int, shadeRules: sbtassembly.Assembly.SeqShadeRules, scalaVersion: String, level: sbt.Level.Value): AssemblyOption = new AssemblyOption(includeBin, includeScala, includeDependency, excludedJars, mergeStrategy, cacheOutput, appendContentHash, Option(prependShellScript), Option(maxHashLength), shadeRules, scalaVersion, level)
def apply(includeBin: Boolean, includeScala: Boolean, includeDependency: Boolean, excludedJars: sbt.Keys.Classpath, repeatableBuild: Boolean, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy, cacheOutput: Boolean, appendContentHash: Boolean, prependShellScript: Option[sbtassembly.Assembly.SeqString], maxHashLength: Option[Int], shadeRules: sbtassembly.Assembly.SeqShadeRules, scalaVersion: String, level: sbt.Level.Value): AssemblyOption = new AssemblyOption(includeBin, includeScala, includeDependency, excludedJars, repeatableBuild, mergeStrategy, cacheOutput, appendContentHash, prependShellScript, maxHashLength, shadeRules, scalaVersion, level)
def apply(includeBin: Boolean, includeScala: Boolean, includeDependency: Boolean, excludedJars: sbt.Keys.Classpath, repeatableBuild: Boolean, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy, cacheOutput: Boolean, appendContentHash: Boolean, prependShellScript: sbtassembly.Assembly.SeqString, maxHashLength: Int, shadeRules: sbtassembly.Assembly.SeqShadeRules, scalaVersion: String, level: sbt.Level.Value): AssemblyOption = new AssemblyOption(includeBin, includeScala, includeDependency, excludedJars, repeatableBuild, mergeStrategy, cacheOutput, appendContentHash, Option(prependShellScript), Option(maxHashLength), shadeRules, scalaVersion, level)
def apply(includeBin: Boolean, includeScala: Boolean, includeDependency: Boolean, excludedJars: sbt.Keys.Classpath, repeatableBuild: Boolean, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy, cacheOutput: Boolean, appendContentHash: Boolean, prependShellScript: Option[sbtassembly.Assembly.SeqString], maxHashLength: Option[Int], shadeRules: sbtassembly.Assembly.SeqShadeRules, keepRules: sbtassembly.Assembly.SeqString, scalaVersion: String, level: sbt.Level.Value): AssemblyOption = new AssemblyOption(includeBin, includeScala, includeDependency, excludedJars, repeatableBuild, mergeStrategy, cacheOutput, appendContentHash, prependShellScript, maxHashLength, shadeRules, keepRules, scalaVersion, level)
def apply(includeBin: Boolean, includeScala: Boolean, includeDependency: Boolean, excludedJars: sbt.Keys.Classpath, repeatableBuild: Boolean, mergeStrategy: sbtassembly.MergeStrategy.StringToMergeStrategy, cacheOutput: Boolean, appendContentHash: Boolean, prependShellScript: sbtassembly.Assembly.SeqString, maxHashLength: Int, shadeRules: sbtassembly.Assembly.SeqShadeRules, keepRules: sbtassembly.Assembly.SeqString, scalaVersion: String, level: sbt.Level.Value): AssemblyOption = new AssemblyOption(includeBin, includeScala, includeDependency, excludedJars, repeatableBuild, mergeStrategy, cacheOutput, appendContentHash, Option(prependShellScript), Option(maxHashLength), shadeRules, keepRules, scalaVersion, level)
}
2 changes: 2 additions & 0 deletions src/main/scala/sbtassembly/AssemblyPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ object AssemblyPlugin extends sbt.AutoPlugin {
override lazy val globalSettings: Seq[Def.Setting[_]] = Seq(
assemblyMergeStrategy := MergeStrategy.defaultMergeStrategy,
assemblyShadeRules := Nil,
assemblyKeepRules := Nil,
assemblyExcludedJars := Nil,
assembleArtifact in packageBin := true,
assembleArtifact in assemblyPackageScala := true,
Expand Down Expand Up @@ -118,6 +119,7 @@ object AssemblyPlugin extends sbt.AutoPlugin {
.withPrependShellScript(assemblyPrependShellScript.value)
.withMaxHashLength(assemblyMaxHashLength.?.value)
.withShadeRules(assemblyShadeRules.value)
.withKeepRules(assemblyKeepRules.value)
.withScalaVersion(scalaVersion.value)
.withLevel(logLevel.?.value.getOrElse(Level.Info))
.withRepeatableBuild(assemblyRepeatableBuild.value)
Expand Down
29 changes: 29 additions & 0 deletions src/sbt-test/shading/keeprules/build.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
lazy val scala212 = "2.12.15"
lazy val scala213 = "2.13.7"

scalaVersion := scala212
crossScalaVersions := List(scala212, scala213)

lazy val keeprules = (project in file(".")).
settings(
version := "0.1",
assembly / assemblyJarName := "foo.jar",
libraryDependencies += "org.apache.commons" % "commons-lang3" % "3.8.1",
assembly / assemblyKeepRules := Seq("keep.**"),
TaskKey[Unit]("check") := {
IO.withTemporaryDirectory { dir
IO.unzip(crossTarget.value / "foo.jar", dir)
mustNotExist(dir / "removed" / "ShadeClass.class")
mustNotExist(dir / "removed" / "ShadePackage.class")
mustExist(dir / "keep" / "Keeped.class")
mustExist(dir / "org" / "apache" / "commons" / "lang3" / "time" / "TimeZones.class")
mustNotExist(dir / "org" / "apache" / "commons" / "lang3" / "time" / "DateParser.class")
}
})

def mustNotExist(f: File): Unit = {
if (f.exists) sys.error("file" + f + " exists!")
}
def mustExist(f: File): Unit = {
if (!f.exists) sys.error("file" + f + " does not exist!")
}
7 changes: 7 additions & 0 deletions src/sbt-test/shading/keeprules/project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
val pluginVersion = System.getProperty("plugin.version")
if(pluginVersion == null)
throw new RuntimeException("""|The system property 'plugin.version' is not defined.
|Specify this property using the scriptedLaunchOpts -D.""".stripMargin)
else addSbtPlugin("com.eed3si9n" % "sbt-assembly" % pluginVersion)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package keep

class Keeped {
def main(args: Array[String]): Unit = {
val myUsedObject = org.apache.commons.lang3.time.TimeZones.GMT_ID
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package removed

class ShadeClass
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package removed

class ShadePackage
7 changes: 7 additions & 0 deletions src/sbt-test/shading/keeprules/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# check if the file gets created
> +assembly
$ exists target/scala-2.12/foo.jar
$ exists target/scala-2.13/foo.jar

# check if it says hello
> +check

0 comments on commit 216b65a

Please sign in to comment.