package mill.javalib.zinc

private trait TransformingReporter(
    color: Boolean,
    optPositionMapper: (xsbti.Position => xsbti.Position) | Null,
    workspaceRoot: os.Path
) extends xsbti.Reporter {

  /**
   * Stackable override that rewrites every problem before handing it to the underlying
   * reporter: positions (including those of actions and related information, which the
   * LoggedReporter does not transform on its own) go through `optPositionMapper`, and the
   * rendered message is regenerated with workspace-relative paths.
   */
  abstract override def log(problem0: xsbti.Problem): Unit = {
    val suppliedMapper = optPositionMapper
    // With no mapper (e.g. for build files) we still transform using the identity mapping,
    // so that the path relativization in the rendered message is always applied
    val positionMapper: xsbti.Position => xsbti.Position =
      if suppliedMapper != null then suppliedMapper
      else identity
    super.log(
      TransformingReporter.transformProblem(color, problem0, positionMapper, workspaceRoot)
    )
  }
}

private object TransformingReporter {

  import sbt.util.InterfaceUtil

  import scala.jdk.CollectionConverters.given

  /**
   * Returns a copy of `problem0` in which the primary position, the positions of the
   * diagnostic related information, and the positions of the action edits have all been
   * passed through `mapper`. The `rendered` text is regenerated via `dottyStyleMessage`.
   *
   * A new problem object is always built, because the rendered text is always recomputed.
   * Only the related-information and action collections are reused as-is when `mapper`
   * leaves every one of their positions untouched.
   */
  private def transformProblem(
      color: Boolean,
      problem0: xsbti.Problem,
      mapper: xsbti.Position => xsbti.Position,
      workspaceRoot: os.Path
  ): xsbti.Problem = {
    val unMappedPos = problem0.position()
    val related0 = problem0.diagnosticRelatedInformation()
    val actions0 = problem0.actions()
    val pos = mapper(unMappedPos)
    val related = transformRelateds(related0, mapper)
    val actions = transformActions(actions0, mapper)
    val rendered =
      dottyStyleMessage(color, problem0, pos = pos, unMappedPos = unMappedPos, workspaceRoot)
    InterfaceUtil.problem(
      cat = problem0.category(),
      pos = pos,
      msg = problem0.message(),
      sev = problem0.severity(),
      rendered = Some(rendered),
      diagnosticCode = InterfaceUtil.jo2o(problem0.diagnosticCode()),
      diagnosticRelatedInformation = anyToList(related),
      actions = anyToList(actions)
    )
  }

  /** Either the untouched Java list from the compiler, or a freshly transformed Scala list. */
  private type JOrSList[T] = java.util.List[T] | List[T]

  /** Normalizes a [[JOrSList]] to a Scala `List`, copying only when given a Java list. */
  private def anyToList[T](ts: JOrSList[T]): List[T] = ts match {
    case ts: List[T] => ts
    case ts: java.util.List[T] => ts.asScala.toList
  }

  /**
   * Renders the message in the style of dotty.
   *
   * Shades output by severity when `color` is set and shows the source path relative to
   * `workspaceRoot` when it lies inside it. Includes the offending source line
   * (syntax-highlighted for Java files when `color` is set). Delegates the final layout to
   * `mill.constants.Util.formatError`.
   *
   * Returns the bare problem message when `pos` has no source path.
   */
  private def dottyStyleMessage(
      color: Boolean,
      problem0: xsbti.Problem,
      pos: xsbti.Position,
      unMappedPos: xsbti.Position,
      workspaceRoot: os.Path
  ): String = {

    val severity = problem0.severity()

    val shade: java.util.function.Function[String, String] =
      if color then
        severity match {
          case xsbti.Severity.Error => msg => Console.RED + msg + Console.RESET
          case xsbti.Severity.Warn => msg => Console.YELLOW + msg + Console.RESET
          case xsbti.Severity.Info => msg => Console.BLUE + msg + Console.RESET
        }
      else msg => msg

    val message = problem0.message()

    InterfaceUtil.jo2o(pos.sourcePath()) match {
      case None => message
      case Some(path) =>
        // `os.Path` throws on relative or malformed paths. Rather than letting that escape
        // from the reporter (losing the diagnostic entirely), fall back to treating the
        // path as an opaque string: show it verbatim and don't try to read the file.
        val absPathOpt = scala.util.Try(os.Path(path)).toOption

        // Render paths within the current workspaceRoot as relative paths to cut down on verbosity
        val displayPath = absPathOpt match {
          case Some(absPath) if absPath.startsWith(workspaceRoot) =>
            absPath.subRelativeTo(workspaceRoot).toString
          case _ => path
        }

        val line = intValue(pos.line(), -1)
        val pointer0 = intValue(pos.pointer(), -1)
        val colNum = pointer0 + 1

        val space = pos.pointerSpace().orElse("")
        val endCol = intValue(pos.endColumn(), pointer0 + 1)

        // Dotty only renders the colored code snippet as part of `.rendered`, but it's mixed
        // in with the rest of the UI we don't really want. So we need to scrape it out ourselves
        val renderedLines = InterfaceUtil.jo2o(problem0.rendered())
          .iterator
          .flatMap(_.linesIterator)
          .toSeq

        // Scrape the relevant line from the dotty error code snippet, because dotty defaults to
        // rendering entire expressions which can be arbitrarily large and spammy in the terminal.
        val scraped = mill.api.internal.Util.scrapeColoredLineContent(
          renderedLines,
          // Use the unmapped line to scrape the corresponding line from the error message,
          // since the raw compiler error would not have gone through line mapping
          intValue(unMappedPos.line(), -1),
          pos.lineContent()
        )

        // Some errors like Java `unclosed string literal` errors don't provide any
        // message at all to `rendered` for us to scrape the line content, and others
        // like `cannot find symbol` have incorrect line `.lineContent()`s, so for
        // all Java errors just scrape the line from the filesystem
        val isJavaFile = absPathOpt.exists(_.ext == "java")
        val lineContent0 =
          if (scraped == "" || isJavaFile) absPathOpt.fold("")(readSourceLine(_, line))
          else scraped

        // Apply syntax highlighting to Java source code lines
        val lineContent =
          if (color && isJavaFile && lineContent0.nonEmpty) {
            HighlightJava.highlightJavaCode(
              lineContent0,
              literalColor = fansi.Color.Green,
              keywordColor = fansi.Color.Yellow,
              commentColor = fansi.Color.Blue,
              definitionColor = fansi.Color.Cyan
            ).render
          } else lineContent0

        val pointerLength =
          if (space.nonEmpty && pointer0 >= 0 && endCol >= 0) {
            // Use the plaintext length of lineContent, since it may have color codes.
            // This is only used to bound the pointer, so strip escape sequences fansi
            // doesn't recognise instead of throwing on them.
            val plainLineLength = fansi.Str(lineContent, fansi.ErrorMode.Strip).length
            math.max(1, math.min(endCol - pointer0, plainLineLength - space.length))
          } else 1

        mill.constants.Util.formatError(
          displayPath,
          line,
          colNum,
          lineContent,
          message,
          pointerLength,
          shade
        )
    }
  }

  /**
   * Reads the 1-based `line` of `file`. Returns "" if the line number is out of range or
   * the file cannot be read (reading is best-effort, only used to improve the message).
   */
  private def readSourceLine(file: os.Path, line: Int): String =
    if line < 1 then ""
    else
      scala.util.Try(os.read.lines(file).lift(line - 1))
        .toOption
        .flatten
        .getOrElse("")

  /**
   * Maps the positions of all text edits of all actions through `mapper`.
   *
   * Returns `actions0` itself when `mapper` returns the very same (reference-identical)
   * position for every edit. Otherwise returns a new Scala list of rebuilt actions.
   */
  private def transformActions(
      actions0: java.util.List[xsbti.Action],
      mapper: xsbti.Position => xsbti.Position
  ): JOrSList[xsbti.Action] = {
    if actions0.iterator().asScala.exists(a =>
        a.edit().changes().iterator().asScala.exists(e =>
          mapper(e.position()) ne e.position()
        )
      )
    then {
      actions0.iterator().asScala.map(transformAction(_, mapper)).toList
    } else {
      actions0
    }
  }

  /**
   * Maps the positions of all diagnostic related information through `mapper`.
   *
   * Returns `related0` itself when `mapper` returns the very same (reference-identical)
   * position for every entry. Otherwise returns a new Scala list of rebuilt entries.
   */
  private def transformRelateds(
      related0: java.util.List[xsbti.DiagnosticRelatedInformation],
      mapper: xsbti.Position => xsbti.Position
  ): JOrSList[xsbti.DiagnosticRelatedInformation] = {

    if related0.iterator().asScala.exists(r => mapper(r.position()) ne r.position()) then
      related0.iterator().asScala.map(transformRelated(_, mapper)).toList
    else
      related0
  }

  /** Rebuilds one related-information entry with its position mapped. */
  private def transformRelated(
      related0: xsbti.DiagnosticRelatedInformation,
      mapper: xsbti.Position => xsbti.Position
  ): xsbti.DiagnosticRelatedInformation = {
    InterfaceUtil.diagnosticRelatedInformation(mapper(related0.position()), related0.message())
  }

  /** Rebuilds one action, keeping title and description, with its edit's positions mapped. */
  private def transformAction(
      action0: xsbti.Action,
      mapper: xsbti.Position => xsbti.Position
  ): xsbti.Action = {
    InterfaceUtil.action(
      title = action0.title(),
      description = InterfaceUtil.jo2o(action0.description()),
      edit = transformEdit(action0.edit(), mapper)
    )
  }

  /** Rebuilds a workspace edit with every contained text edit's position mapped. */
  private def transformEdit(
      edit0: xsbti.WorkspaceEdit,
      mapper: xsbti.Position => xsbti.Position
  ): xsbti.WorkspaceEdit = {
    InterfaceUtil.workspaceEdit(
      edit0.changes().iterator().asScala.map(transformTEdit(_, mapper)).toList
    )
  }

  /** Rebuilds a single text edit with its position mapped and its new text unchanged. */
  private def transformTEdit(
      edit0: xsbti.TextEdit,
      mapper: xsbti.Position => xsbti.Position
  ): xsbti.TextEdit = {
    InterfaceUtil.textEdit(
      position = mapper(edit0.position()),
      newText = edit0.newText()
    )
  }
}
