diff --git a/modules/build/src/main/scala/scala/build/preprocessing/SheBang.scala b/modules/build/src/main/scala/scala/build/preprocessing/SheBang.scala index 3b1b29fb56..46dcfb3a8f 100644 --- a/modules/build/src/main/scala/scala/build/preprocessing/SheBang.scala +++ b/modules/build/src/main/scala/scala/build/preprocessing/SheBang.scala @@ -7,6 +7,19 @@ object SheBang { def isShebangScript(content: String): Boolean = sheBangRegex.unanchored.matches(content) + /** Returns the shebang section and the content without the shebang section */ + def partitionOnShebangSection(content: String): (String, String) = + if (content.startsWith("#!")) { + val regexMatch = sheBangRegex.findFirstMatchIn(content) + regexMatch match { + case Some(firstMatch) => + (firstMatch.toString(), content.replace(firstMatch.toString(), "")) + case None => ("", content) + } + } + else + ("", content) + def ignoreSheBangLines(content: String): (String, Boolean) = if (content.startsWith("#!")) { val regexMatch = sheBangRegex.findFirstMatchIn(content) diff --git a/modules/cli/src/main/scala/scala/cli/commands/fix/Fix.scala b/modules/cli/src/main/scala/scala/cli/commands/fix/Fix.scala index d130d7a8cb..fd0feeade1 100644 --- a/modules/cli/src/main/scala/scala/cli/commands/fix/Fix.scala +++ b/modules/cli/src/main/scala/scala/cli/commands/fix/Fix.scala @@ -148,9 +148,9 @@ object Fix extends ScalaCommand[FixOptions] { val logger = loggingUtilities.logger val fromPaths = sources.paths.map { (path, _) => - val content = os.read(path).toCharArray + val (_, content) = SheBang.partitionOnShebangSection(os.read(path)) logger.debug(s"Extracting directives from ${loggingUtilities.relativePath(path)}") - ExtractedDirectives.from(content, Right(path), logger, _ => None).orExit(logger) + ExtractedDirectives.from(content.toCharArray, Right(path), logger, _ => None).orExit(logger) } val fromInMemory = sources.inMemory.map { inMem => @@ -158,7 +158,7 @@ object Fix extends ScalaCommand[FixOptions] { val content = originOrPath match { case Right(path) => logger.debug(s"Extracting directives from ${loggingUtilities.relativePath(path)}") - os.read(path).toCharArray + os.read(path) case Left(origin) => logger.debug(s"Extracting directives from $origin") inMem.wrapperParamsOpt match { @@ -166,12 +166,19 @@ object Fix extends ScalaCommand[FixOptions] { case Some(wrapperParams) => String(inMem.content) .linesWithSeparators .drop(wrapperParams.topWrapperLineCount) - .mkString.toCharArray - case None => inMem.content.map(_.toChar) + .mkString + case None => String(inMem.content) } } - ExtractedDirectives.from(content, originOrPath, logger, _ => None).orExit(logger) + val (_, contentWithNoShebang) = SheBang.partitionOnShebangSection(content) + + ExtractedDirectives.from( + contentWithNoShebang.toCharArray, + originOrPath, + logger, + _ => None + ).orExit(logger) } fromPaths ++ fromInMemory @@ -244,8 +251,16 @@ object Fix extends ScalaCommand[FixOptions] { ): Unit = { position match { case Some(Position.File(Right(path), _, _, offset)) => - val keepLines = toKeep.mkString("", newLine, newLine * 2) - val newContents = keepLines + os.read(path).drop(offset).stripLeading() + val (shebangSection, strippedContent) = SheBang.partitionOnShebangSection(os.read(path)) + + def ignoreOrAddNewLine(str: String) = if str.isBlank then "" else str + newLine + + val keepLines = ignoreOrAddNewLine(shebangSection) + ignoreOrAddNewLine(toKeep.mkString( + "", + newLine, + newLine + )) + val newContents = keepLines + strippedContent.drop(offset).stripLeading() val relativePath = loggingUtilities.relativePath(path) loggingUtilities.logger.message(s"Removing directives from $relativePath") diff --git a/modules/integration/src/test/scala/scala/cli/integration/FixTests.scala b/modules/integration/src/test/scala/scala/cli/integration/FixTests.scala index c42c59ef28..48a37932a0 100644 --- a/modules/integration/src/test/scala/scala/cli/integration/FixTests.scala +++ b/modules/integration/src/test/scala/scala/cli/integration/FixTests.scala @@ -73,6 +73,65 @@ class FixTests extends ScalaCliSuite { } } + test("fix script with shebang") { + val mainFileName = "main.sc" + val inputs = TestInputs( + os.rel / mainFileName -> + s"""#!/usr/bin/env -S scala-cli shebang + | + |//> using objectWrapper + |//> using dep com.lihaoyi::os-lib:0.9.1 com.lihaoyi::upickle:3.1.2 + | + |println(os.pwd) + |""".stripMargin, + os.rel / projectFileName -> + s"""//> using lib "com.lihaoyi::pprint:0.6.6" + |""".stripMargin + ) + + inputs.fromRoot { root => + + val fixOutput = os.proc(TestUtil.cli, "--power", "fix", ".", "-v", "-v", extraOptions) + .call(cwd = root, mergeErrIntoOut = true).out.trim() + + assertNoDiff( + fixOutput, + """Extracting directives from project.scala + |Extracting directives from main.sc + |Removing directives from project.scala + |Removing directives from main.sc""".stripMargin + ) + + val projectFileContents = os.read(root / projectFileName) + val mainFileContents = os.read(root / mainFileName) + + assertNoDiff( + projectFileContents, + """// Main + |//> using objectWrapper + | + |//> using dependency "com.lihaoyi::os-lib:0.9.1" + |//> using dependency "com.lihaoyi::pprint:0.6.6" + |//> using dependency "com.lihaoyi::upickle:3.1.2" + | + |""".stripMargin + ) + + assertNoDiff( + mainFileContents, + """#!/usr/bin/env -S scala-cli shebang + | + |println(os.pwd) + |""".stripMargin + ) + + val runProc = os.proc(TestUtil.cli, "--power", "compile", ".", extraOptions) + .call(cwd = root, stderr = os.Pipe) + + expect(!runProc.err.trim.contains("Using directives detected in multiple files")) + } + } + test("fix with test scope") { val mainSubPath = os.rel / "src" / "Main.scala" val testSubPath = os.rel / "test" / "MyTests.scala"