package mill.api.internal

import fastparse.NoWhitespace.noWhitespaceImplicit
import fastparse.*
import mill.api.{Result, Segment, Segments, SelectMode}

import scala.annotation.tailrec

/**
 * Parsing utilities for Mill task selectors and module references.
 * This is internal to Mill and not part of the public API.
 */
private[mill] object ParseArgs {

  /**
   * Parses a shell command string into arguments, handling single and double quotes
   * and backslash escapes. Similar to how bash parses arguments, and compatible with
   * JDK_JAVA_OPTIONS quoting rules.
   *
   * Arguments are separated by runs of spaces and tabs. Quoted and unquoted parts that
   * are adjacent (no whitespace between them) are concatenated into a single argument,
   * e.g. `foo'bar'"baz"` yields `foobarbaz`.
   *
   * @param input the raw command string
   * @return the parsed arguments in order; empty when `input` is empty or only whitespace
   * @throws IllegalArgumentException if `input` cannot be parsed, e.g. because of an
   *                                  unterminated quote or a trailing lone backslash
   */
  def parseShellArgs(input: String): Seq[String] = {
    // Single-quoted: no escape sequences (like bash)
    def singleQuoted[$: P]: P[String] = P("'" ~/ CharsWhile(_ != '\'', 0).! ~ "'")

    // Double-quoted: only \" and \\ are unescaped; any other backslash sequence is kept verbatim
    def doubleQuotedChar[$: P]: P[String] = P(
      ("\\" ~ AnyChar.!).map {
        case "\"" => "\""
        case "\\" => "\\"
        case c => "\\" + c // preserve unknown escapes
      } | CharsWhile(c => c != '"' && c != '\\', 1).!
    )
    // The cut (~/) commits to this alternative once a quote is seen, so an unterminated
    // quote is reported as a failure instead of being re-parsed as unquoted text.
    def doubleQuoted[$: P]: P[String] = P("\"" ~/ doubleQuotedChar.rep.map(_.mkString) ~ "\"")

    // Unquoted: a backslash makes the following character literal (e.g. `\ ` or `\'`)
    def unquotedChar[$: P]: P[String] = P(
      ("\\" ~ AnyChar.!) |
        CharsWhile(c => c != ' ' && c != '\t' && c != '\'' && c != '"' && c != '\\', 1).!
    )
    def unquoted[$: P]: P[String] = P(unquotedChar.rep(1).map(_.mkString))

    def argPart[$: P]: P[String] = P(singleQuoted | doubleQuoted | unquoted)
    // Adjacent parts with no whitespace between them form one argument
    def arg[$: P]: P[String] = P(argPart.rep(1).map(_.mkString))
    def whitespace[$: P]: P[Unit] = P(CharsWhileIn(" \t", 1))
    def parser[$: P]: P[Seq[String]] =
      P(whitespace.? ~ arg.rep(sep = whitespace) ~ whitespace.? ~ End)

    fastparse.parse(input, parser(using _)) match {
      case Parsed.Success(result, _) => result
      case f: Parsed.Failure =>
        throw new IllegalArgumentException(s"Failed to parse shell args: ${f.msg}")
    }
  }

  /**
   * One parsed task group: the selectors (each as returned by [[extractSegments]])
   * together with the arguments that follow them.
   */
  type TasksWithParams = (Seq[(String, Segments)], Seq[String])

  /** Separator used in multiSelect-mode to separate tasks from their args. */
  val MultiArgsSeparator: String = "--"

  /** Separator used in [[SelectMode.Separated]] mode to separate a task-args-tuple from the next target. */
  val TaskSeparator: String = "+"

  /**
   * Matches a [[TaskSeparator]] masked by one or more leading backslashes (`\+`, `\\+`, ...).
   * [[separate]] strips exactly one backslash from such arguments, so `\+` is passed
   * through as a literal `+` and `\\+` as `\+`.
   */
  val MaskPattern: scala.util.matching.Regex = ("""\\+\Q""" + TaskSeparator + """\E""").r

  /**
   * Partitions `scriptArgs` into groups at each [[TaskSeparator]].
   *
   * To use the separator itself as an argument, mask it with a backslash (see [[MaskPattern]]).
   * One backslash is removed from each masked argument.
   *
   * Edge cases:
   *  - empty input yields a single empty group;
   *  - a leading separator yields an empty first group;
   *  - a single trailing separator is dropped without producing an empty group.
   *
   * @param scriptArgs the raw command-line arguments
   * @return the argument groups, in order
   */
  def separate(scriptArgs: Seq[String]): Seq[Seq[String]] = {

    // Turns a masked separator (`\+`, `\\+`, ...) back into its literal by dropping one backslash
    def unmask(arg: String): String = arg match {
      case MaskPattern(_*) => arg.drop(1)
      case _ => arg
    }

    @tailrec
    def separated(result: Seq[Seq[String]], rest: Seq[String]): Seq[Seq[String]] = rest match {
      case Seq() => if (result.nonEmpty) result else Seq(Seq())
      case r =>
        val (next, r2) = r.span(_ != TaskSeparator)
        // `r2` starts with the separator (or is empty); drop(1) discards it
        separated(result :+ next.map(unmask), r2.drop(1))
    }
    separated(Seq() /* start value */, scriptArgs)
  }

  /**
   * Parses a full command line into one result per task group.
   *
   * The arguments are first split at [[TaskSeparator]] (see [[separate]]). Each group is
   * then parsed independently by [[extractAndValidate]]. Multi-select parsing is used
   * when `selectMode` is [[SelectMode.Multi]].
   *
   * @param scriptArgs the raw command-line arguments
   * @param selectMode how selectors and arguments are distinguished within a group
   * @return one parse result per group, in order
   */
  def apply(scriptArgs: Seq[String], selectMode: SelectMode): Seq[Result[TasksWithParams]] = {
    separate(scriptArgs).map(extractAndValidate(_, selectMode == SelectMode.Multi))
  }

  /**
   * Parses a single task group.
   *
   * The group is split into selectors and arguments (see [[extractSelsAndArgs]]). Empty
   * selectors are rejected. Each selector is brace-expanded with `ExpandBraces.expandBraces`
   * and every expanded selector is parsed with [[extractSegments]].
   *
   * @param scriptArgs  the arguments of one task group
   * @param multiSelect whether several selectors may precede a [[MultiArgsSeparator]]
   * @return the parsed selectors with the group's arguments, or a failure if any step fails
   */
  def extractAndValidate(
      scriptArgs: Seq[String],
      multiSelect: Boolean
  ): Result[TasksWithParams] = {
    val (selectors, args) = extractSelsAndArgs(scriptArgs, multiSelect)
    for {
      _ <- validateSelectors(selectors)
      expandedSelectors <- Result
        .sequence(selectors.map(ExpandBraces.expandBraces))
        .map(_.flatten)
      selectors <- Result.sequence(expandedSelectors.map(extractSegments))
    } yield (selectors.iterator.toList, args)
  }

  /**
   * Splits a task group into selector strings and task arguments.
   *
   * In multi-select mode, everything before the first [[MultiArgsSeparator]] is a selector
   * and everything after it is an argument. If there is no separator, all entries are
   * selectors. Otherwise only the first entry is a selector and the rest are arguments.
   *
   * @param scriptArgs  the arguments of one task group
   * @param multiSelect whether multi-select mode is active
   * @return `(selectors, args)`
   */
  def extractSelsAndArgs(
      scriptArgs: Seq[String],
      multiSelect: Boolean
  ): (Seq[String], Seq[String]) = {

    if (multiSelect) {
      val dd = scriptArgs.indexOf(MultiArgsSeparator)
      val selectors = if (dd == -1) scriptArgs else scriptArgs.take(dd)
      val args = if (dd == -1) Seq.empty else scriptArgs.drop(dd + 1)

      (selectors, args)
    } else {
      (scriptArgs.take(1), scriptArgs.drop(1))
    }
  }

  /** Fails when there are no selectors at all or when any selector is the empty string. */
  private def validateSelectors(selectors: Seq[String]): Result[Unit] = {
    if (selectors.isEmpty || selectors.exists(_.isEmpty)) {
      Result.Failure(
        "Task selector must not be empty. Try `mill resolve _` to see what's available."
      )
    } else Result.Success(())
  }

  /**
   * Parses one (already brace-expanded) selector string.
   *
   * @param selectorString the selector, e.g. `foo.bar`, `foo[2.13].compile` or `foo.super`
   * @return `(prefix, segments)`. If the selector contains no top-level `/` or `:`
   *         separator, `prefix` is empty and `segments` describes the whole selector.
   *         Otherwise `prefix` is the rendered part before the separator followed by the
   *         separator itself, and `segments` is the (possibly empty) part after it.
   *         A failure is returned for a malformed selector.
   */
  def extractSegments(selectorString: String)
      : Result[(String, Segments)] =
    parse(selectorString, selector(using _)) match {
      case f: Parsed.Failure => Result.Failure(s"Parsing exception ${f.msg}")
      case Parsed.Success(selector, _) => Result.Success(selector)
    }

  /** fastparse grammar for a single selector; see [[extractSegments]] for the result shape. */
  private def selector[_p: P]: P[(String, Segments)] = {
    def wildcard = P("__" | "_")
    def label = P(CharsWhileIn("a-zA-Z0-9_\\-")).!
    // Match "foo.super" as a single label to support super task invocation.
    // The negative lookahead (same character class as `label`) requires ".super" to end the
    // label. Without it, "foo.supervisor" would match as "foo.super" + leftover "visor", and
    // since ordered choice never revisits a successful alternative, the selector would fail.
    def labelWithSuper = P(label ~~ ".super" ~~ !CharIn("a-zA-Z0-9_\\-")).!

    // Optional `^`/`!` prefix plus a type name. Dotted names are only allowed when not
    // `simple`, i.e. inside parentheses.
    def typeQualifier(simple: Boolean) = {
      val maxSegments = if (simple) 0 else Int.MaxValue
      P(("^" | "!").? ~~ label ~~ ("." ~~ label).rep(max = maxSegments)).!
    }

    // A wildcard followed by one or more `:Type` qualifiers, e.g. `__:TestModule`
    def typePattern(simple: Boolean) = P(wildcard ~~ (":" ~~ typeQualifier(simple)).rep(1)).!

    def segment0(simple: Boolean) =
      P(typePattern(simple) | labelWithSuper | label).map(Segment.Label(_))
    def segment = P("(" ~ segment0(false) ~ ")" | segment0(true))

    def identCross = P(CharsWhileIn("a-zA-Z0-9_\\-.")).!
    def crossSegment = P("[" ~ identCross.rep(1, sep = ",") ~ "]").map(Segment.Cross(_))
    def defaultCrossSegment = P("[]").map(_ => Segment.Cross(Seq()))

    def simpleQuery = P(
      (segment | crossSegment | defaultCrossSegment) ~ ("." ~ segment | crossSegment | defaultCrossSegment).rep
    ).map {
      case (h, rest) => Segments(h +: rest)
    }

    P(simpleQuery ~ (("/" | ":").! ~ simpleQuery.?).? ~ End).map {
      case (q, None) => ("", q)
      case (q, Some((sep, q2))) => (q.render + sep, q2.getOrElse(Segments()))
    }
  }
}
