package mill.javalib

import mill.api.PathRef
import mill.api.Result
import mill.util.JarManifest
import mill.api.*
import mill.api.Task.Simple as T
import mill.javalib.Assembly.UnopenedInputStream
import mill.util.Jvm

import scala.annotation.nowarn

/**
 * Module that provides functionality around creating and configuring JVM assembly jars.
 *
 * The assembly is built in three stages so that unchanged parts can be reused:
 *  1. [[resolvedIvyAssembly]] - third-party dependencies only,
 *  2. [[upstreamAssembly]] - adds locally-built upstream modules on top of (1),
 *  3. [[assembly]] - adds this module's own classpath on top of (2), optionally
 *     prefixed with a launcher shell script.
 */
trait AssemblyModule extends OfflineSupportModule {
  outer =>

  /**
   * The main class used for the assembly manifest and the launcher script.
   * Only the `Right` case is used by this trait (via `.toOption`); a `Left`
   * is treated as "no main class".
   */
  def finalMainClassOpt: T[Either[String, String]]

  /**
   * JVM arguments. Within this trait they are embedded into the launcher script
   * produced by [[prependShellScript]] for both the `sh` and `bat` variants.
   */
  def forkArgs: T[Seq[String]]

  /**
   * Similar to `forkArgs` but only applies to the `sh` launcher script
   */
  def forkShellArgs: T[Seq[String]] = Task { Seq.empty[String] }

  /**
   * Similar to `forkArgs` but only applies to the `bat` launcher script
   */
  def forkCmdArgs: T[Seq[String]] = Task { Seq.empty[String] }

  /**
   * Creates a manifest representation which can be modified or replaced.
   * The default implementation just adds the `Manifest-Version`, `Main-Class` and `Created-By` attributes.
   */
  def manifest: T[JarManifest] = Task { manifest0() }

  /** Default manifest: built from [[finalMainClassOpt]] when it is a `Right`. */
  private[mill] def manifest0: T[JarManifest] = Task {
    Jvm.createManifest(finalMainClassOpt().toOption)
  }

  /**
   * What shell script to use to launch the executable generated by `assembly`.
   * Defaults to a generic "universal" launcher that should work for Windows,
   * OS-X and Linux. Return an empty string to disable the prepended script.
   */
  def prependShellScript: T[String] = Task {
    prependShellScript0()
  }

  /**
   * Default launcher script. It is empty when there is no main class. Otherwise it is a
   * universal script that runs the jar itself: `$0` for `sh`, `%~dpnx0` for `bat`.
   */
  private[mill] def prependShellScript0: T[String] = Task {
    finalMainClassOpt().toOption match {
      case None => ""
      case Some(cls) =>
        Jvm.launcherUniversalScript(
          mainClass = cls,
          shellClassPath = Seq("$0"),
          cmdClassPath = Seq("%~dpnx0"),
          jvmArgs = forkArgs(),
          shebang = false,
          shellJvmArgs = forkShellArgs(),
          cmdJvmArgs = forkCmdArgs()
        )
    }
  }

  /**
   * Rules passed to `Assembly.create` in every assembly stage of this module.
   * Defaults to `Assembly.defaultRules`.
   */
  def assemblyRules: Seq[Assembly.Rule] = assemblyRules0

  private[mill] def assemblyRules0: Seq[Assembly.Rule] = Assembly.defaultRules

  /**
   * Upstream classfiles and resources from third-party libraries
   * necessary to build an executable assembly
   */
  def upstreamIvyAssemblyClasspath: T[Seq[PathRef]]

  /**
   * Upstream classfiles and resources from locally-built modules
   * necessary to build an executable assembly, but without this module's contribution
   */
  def upstreamLocalAssemblyClasspath: T[Seq[PathRef]]

  /** This module's own classpath entries, added in the final [[assembly]] stage. */
  def localClasspath: T[Seq[PathRef]]

  /**
   * Build the assembly for third-party dependencies separate from the current
   * classpath
   *
   * This should allow much faster assembly creation in the common case where
   * third-party dependencies do not change
   */
  def resolvedIvyAssembly: T[Assembly] = Task {
    Assembly.create(
      destJar = Task.dest / "out.jar",
      inputPaths = upstreamIvyAssemblyClasspath().map(_.path),
      manifest = manifest(),
      assemblyRules = assemblyRules,
      shader = AssemblyModule.jarjarabramsWorker()
    )
  }

  /**
   * Build the assembly for upstream dependencies separate from the current
   * classpath
   *
   * This should allow much faster assembly creation in the common case where
   * upstream dependencies do not change
   */
  def upstreamAssembly: T[Assembly] = Task {
    Assembly.create(
      destJar = Task.dest / "out.jar",
      inputPaths = upstreamLocalAssemblyClasspath().map(_.path),
      manifest = manifest(),
      base = Some(resolvedIvyAssembly()),
      assemblyRules = assemblyRules,
      shader = AssemblyModule.jarjarabramsWorker()
    )
  }

  /**
   * An executable uber-jar/assembly containing all the resources and compiled
   * classfiles from this module and all its upstream modules and dependencies.
   *
   * Fails if a launcher script is prepended and the resulting jar has more ZIP
   * entries than is known to work with a prepended script.
   */
  def assembly: T[PathRef] = Task {
    // An empty script means "do not prepend anything".
    val prependScript = Option(prependShellScript()).filter(_.nonEmpty)
    val upstream = upstreamAssembly()

    val created = Assembly.create(
      destJar = Task.dest / "out.jar",
      inputPaths = localClasspath().map(_.path),
      manifest = manifest(),
      prependShellScript = prependScript,
      base = Some(upstream),
      assemblyRules = assemblyRules,
      shader = AssemblyModule.jarjarabramsWorker()
    )
    // See https://github.com/com-lihaoyi/mill/pull/2655#issuecomment-1672468284
    val problematicEntryCount = 65535

    if (prependScript.isDefined && created.entries > problematicEntryCount) {
      Task.fail(
        s"""The created assembly jar contains more than ${problematicEntryCount} ZIP entries.
           |JARs of that size are known to not work correctly with a prepended shell script.
           |Either reduce the entries count of the assembly or disable the prepended shell script with:
           |
           |  def prependShellScript = ""
           |""".stripMargin
      )
    } else {
      created.pathRef
    }
  }

  /**
   * Extends the parent's offline preparation with this module's third-party
   * assembly classpath and the offline artifacts of the shared [[AssemblyModule]]
   * external module (the jarjarabrams worker). Duplicates are removed.
   */
  override def prepareOffline(all: mainargs.Flag): Task.Command[Seq[PathRef]] = Task.Command {
    (
      super.prepareOffline(all)() ++
        upstreamIvyAssemblyClasspath() ++
        AssemblyModule.prepareOffline(all)()
    ).distinct
  }
}

/**
 * Shared external module hosting the jarjarabrams-based shading worker used by
 * every [[AssemblyModule]] when building assemblies.
 */
object AssemblyModule extends ExternalModule with CoursierModule with OfflineSupportModule {

  /** Resolved classpath of the `mill-libs-javalib-jarjarabrams-worker` artifact. */
  def jarjarabramsWorkerClasspath: T[Seq[PathRef]] = Task {
    defaultResolver().classpath(Seq(
      Dep.millProjectModule("mill-libs-javalib-jarjarabrams-worker")
    ))
  }

  /** Adds the jarjarabrams worker classpath to the parent's offline artifacts, de-duplicated. */
  override def prepareOffline(all: mainargs.Flag): Task.Command[Seq[PathRef]] = Task.Command {
    (
      super.prepareOffline(all)() ++
        jarjarabramsWorkerClasspath()
    ).distinct
  }

  /**
   * Classloader over [[jarjarabramsWorkerClasspath]], parented to the classloader
   * that loaded this module.
   */
  private[mill] def jarjarabramsWorkerClassloader: Task.Worker[ClassLoader] = Task.Worker {
    Jvm.createClassLoader(
      classPath = jarjarabramsWorkerClasspath().map(_.path),
      parent = getClass().getClassLoader()
    )
  }

  /**
   * Shading function passed to `Assembly.create`.
   *
   * It takes the relocation rules (pairs of `String`s), an entry name and an unopened
   * input stream. It reflectively delegates to the first public `apply` method of
   * `mill.javalib.jarjarabrams.impl.JarJarAbramsWorkerImpl`, invoked with a `null`
   * receiver, and returns that method's result cast to
   * `Option[(String, UnopenedInputStream)]`.
   *
   * The reflective lookup runs once, on first invocation, and is then reused for
   * the lifetime of this worker instance instead of being repeated per entry.
   *
   * @throws NoSuchMethodException if the implementation class has no `apply` method
   */
  @nowarn("msg=.*Workers should implement AutoCloseable.*")
  def jarjarabramsWorker
      : Task.Worker[(Seq[(String, String)], String, UnopenedInputStream) => Option[(
          String,
          UnopenedInputStream
      )]] = Task.Worker {
    val workerClassLoader = jarjarabramsWorkerClassloader()
    // Lazy so that failures (missing class/method) still surface on first use,
    // as before, rather than when the worker is created.
    lazy val applyMethod: java.lang.reflect.Method = {
      val implClass =
        workerClassLoader.loadClass("mill.javalib.jarjarabrams.impl.JarJarAbramsWorkerImpl")
      implClass.getMethods
        .find(_.getName == "apply")
        .getOrElse(throw new NoSuchMethodException(s"${implClass.getName}.apply"))
    }
    (relocates: Seq[(String, String)], name: String, is: UnopenedInputStream) =>
      applyMethod
        .invoke(null, relocates, name, is)
        .asInstanceOf[Option[(String, UnopenedInputStream)]]
  }

  override def millDiscover: Discover = Discover[this.type]
}
