diff --git a/CodeConverter/CSharp/CommonConversions.cs b/CodeConverter/CSharp/CommonConversions.cs index 383c5260d..eb46921d2 100644 --- a/CodeConverter/CSharp/CommonConversions.cs +++ b/CodeConverter/CSharp/CommonConversions.cs @@ -145,7 +145,7 @@ public bool ShouldPreferExplicitType(VBSyntax.ExpressionSyntax exp, equalsValueClauseSyntax = null; } else { var returnBlock = SyntaxFactory.Block(SyntaxFactory.ReturnStatement(adjustedInitializerExpr)); - _typeContext.PerScopeState.Hoist(new HoistedParameterlessFunction(GetInitialValueFunctionName(vbName), csTypeSyntax, returnBlock)); + _typeContext.PerScopeState.Hoist(new HoistedFunction(GetInitialValueFunctionName(vbName), csTypeSyntax, returnBlock, null)); equalsValueClauseSyntax = null; } } else { diff --git a/CodeConverter/CSharp/ExpressionNodeVisitor.cs b/CodeConverter/CSharp/ExpressionNodeVisitor.cs index a0e69a0d8..e3a655c89 100644 --- a/CodeConverter/CSharp/ExpressionNodeVisitor.cs +++ b/CodeConverter/CSharp/ExpressionNodeVisitor.cs @@ -1174,9 +1174,41 @@ private async Task HoistAndCallLocalFunctionAsync(VB statements.Concat(SyntaxFactory.ReturnStatement(ValidSyntaxFactory.IdentifierName(retVariableName)).Yield()) ); var returnType = CommonConversions.GetTypeSyntax(invocationSymbol.ReturnType); - - var localFunc = _typeContext.PerScopeState.Hoist(new HoistedParameterlessFunction(localFuncName, returnType, block)); - return SyntaxFactory.InvocationExpression(localFunc.TempIdentifier, SyntaxFactory.ArgumentList()); + + //any argument that's a ByRef parameter of the parent method needs to be passed as a ref parameter to the local function (to avoid error CS1628) + var refParametersOfParent = GetRefParameters(invocation.ArgumentList); + var (args, @params) = CreateArgumentsAndParametersLists(refParametersOfParent); + + var localFunc = _typeContext.PerScopeState.Hoist(new HoistedFunction(localFuncName, returnType, block, SyntaxFactory.ParameterList(@params))); + return SyntaxFactory.InvocationExpression(localFunc.TempIdentifier, SyntaxFactory.ArgumentList(args)); + + List GetRefParameters(VBSyntax.ArgumentListSyntax argumentList) + { + var result = new List(); + if (argumentList is null) return result; + + foreach (var arg in argumentList.Arguments) { + if (_semanticModel.GetSymbolInfo(arg.GetExpression()).Symbol is not IParameterSymbol p) continue; + if (p.RefKind != RefKind.None) { + result.Add(p); + } + } + + return result; + } + + (SeparatedSyntaxList, SeparatedSyntaxList) CreateArgumentsAndParametersLists(List parameterSymbols) + { + var arguments = new List(); + var parameters = new List(); + foreach (var p in parameterSymbols) { + var arg = (ArgumentSyntax)CommonConversions.CsSyntaxGenerator.Argument(p.RefKind, SyntaxFactory.IdentifierName(p.Name)); + arguments.Add(arg); + var par = (ParameterSyntax)CommonConversions.CsSyntaxGenerator.ParameterDeclaration(p); + parameters.Add(par); + } + return (SyntaxFactory.SeparatedList(arguments), SyntaxFactory.SeparatedList(parameters)); + } } private bool RequiresLocalFunction(VBSyntax.InvocationExpressionSyntax invocation, IMethodSymbol invocationSymbol) diff --git a/CodeConverter/CSharp/HoistedParameterlessFunction.cs b/CodeConverter/CSharp/HoistedFunction.cs similarity index 64% rename from CodeConverter/CSharp/HoistedParameterlessFunction.cs rename to CodeConverter/CSharp/HoistedFunction.cs index c2da25be9..cca1c0c83 100644 --- a/CodeConverter/CSharp/HoistedParameterlessFunction.cs +++ b/CodeConverter/CSharp/HoistedFunction.cs @@ -3,23 +3,25 @@ namespace ICSharpCode.CodeConverter.CSharp; -internal class HoistedParameterlessFunction : IHoistedNode +internal class HoistedFunction : IHoistedNode { private readonly TypeSyntax _returnType; private readonly BlockSyntax _block; + private readonly ParameterListSyntax _parameters; public string Id { get; } public string Prefix { get; } - public HoistedParameterlessFunction(string localFuncName, TypeSyntax returnType, BlockSyntax block) + public HoistedFunction(string localFuncName, TypeSyntax returnType, BlockSyntax block, ParameterListSyntax parameters) { Id = $"hs{Guid.NewGuid().ToString("N")}"; Prefix = localFuncName; _returnType = returnType; _block = block; + _parameters = parameters; } public IdentifierNameSyntax TempIdentifier => ValidSyntaxFactory.IdentifierName(Id).WithAdditionalAnnotations(PerScopeState.AdditionalLocalAnnotation); - public LocalFunctionStatementSyntax AsLocalFunction(string functionName) => SyntaxFactory.LocalFunctionStatement(_returnType, SyntaxFactory.Identifier(functionName)).WithBody(_block); - public MethodDeclarationSyntax AsInstanceMethod(string functionName) => ValidSyntaxFactory.CreateParameterlessMethod(functionName, _returnType, _block); + public LocalFunctionStatementSyntax AsLocalFunction(string functionName) => SyntaxFactory.LocalFunctionStatement(_returnType, SyntaxFactory.Identifier(functionName)).WithParameterList(_parameters).WithBody(_block); + public MethodDeclarationSyntax AsInstanceMethod(string functionName) => ValidSyntaxFactory.CreateMethod(functionName, _returnType, _parameters, _block); } \ No newline at end of file diff --git a/CodeConverter/CSharp/PerScopeState.cs b/CodeConverter/CSharp/PerScopeState.cs index c53645b3d..1dfb06b92 100644 --- a/CodeConverter/CSharp/PerScopeState.cs +++ b/CodeConverter/CSharp/PerScopeState.cs @@ -86,9 +86,9 @@ private StatementSyntax[] GetPostStatements() .ToArray(); } - public IReadOnlyCollection GetParameterlessFunctions() + public IReadOnlyCollection GetParameterlessFunctions() { - return _hoistedNodesPerScope.Peek().OfType().ToArray(); + return _hoistedNodesPerScope.Peek().OfType().ToArray(); } public IReadOnlyCollection GetFields() diff --git a/CodeConverter/CSharp/ValidSyntaxFactory.cs b/CodeConverter/CSharp/ValidSyntaxFactory.cs index 616892315..a55fbe3e3 100644 --- a/CodeConverter/CSharp/ValidSyntaxFactory.cs +++ b/CodeConverter/CSharp/ValidSyntaxFactory.cs @@ -74,10 +74,17 @@ expressionSyntax is IdentifierNameSyntax || } public static MethodDeclarationSyntax CreateParameterlessMethod(string newMethodName, TypeSyntax type, BlockSyntax body) + { + var parameterList = SyntaxFactory.ParameterList(); + return CreateMethod(newMethodName, type, parameterList, body); + } + + public static MethodDeclarationSyntax CreateMethod(string newMethodName, TypeSyntax type, ParameterListSyntax parameterList, BlockSyntax body) { var modifiers = SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.StaticKeyword)); var typeConstraints = SyntaxFactory.List(); - var parameterList = SyntaxFactory.ParameterList(); + parameterList ??= SyntaxFactory.ParameterList(); + var methodAttrs = SyntaxFactory.List(); ArrowExpressionClauseSyntax arrowExpression = null; diff --git a/Tests/CSharp/MemberTests/MemberTests.cs b/Tests/CSharp/MemberTests/MemberTests.cs index 9abcbb254..a4805006f 100644 --- a/Tests/CSharp/MemberTests/MemberTests.cs +++ b/Tests/CSharp/MemberTests/MemberTests.cs @@ -385,6 +385,60 @@ public VisualBasicClass() } }"); } + + [Fact] + public async Task TestHoistedOutParameterLambdaUsingByRefParameterAsync() + { + await TestConversionVisualBasicToCSharpAsync( + @"Public Class SomeClass + Sub S(Optional ByRef x As Integer = -1) + Dim i As Integer = 0 + If F1(x, i) Then + ElseIf F2(x, i) Then + ElseIf F3(x, i) Then + End If + End Sub + + Function F1(x As Integer, ByRef o As Object) As Boolean : End Function + Function F2(ByRef x As Integer, ByRef o As Object) As Boolean : End Function + Function F3(ByRef x As Object, ByRef o As Object) As Boolean : End Function +End Class", @"using System.Runtime.InteropServices; +using Microsoft.VisualBasic.CompilerServices; // Install-Package Microsoft.VisualBasic + +public partial class SomeClass +{ + public void S([Optional, DefaultParameterValue(-1)] ref int x) + { + int i = 0; + bool localF1(ref int x) { object argo = i; var ret = F1(x, ref argo); i = Conversions.ToInteger(argo); return ret; } + bool localF2(ref int x) { object argo1 = i; var ret = F2(ref x, ref argo1); i = Conversions.ToInteger(argo1); return ret; } + bool localF3(ref int x) { object argx = x; object argo2 = i; var ret = F3(ref argx, ref argo2); x = Conversions.ToInteger(argx); i = Conversions.ToInteger(argo2); return ret; } + + if (localF1(ref x)) + { + } + else if (localF2(ref x)) + { + } + else if (localF3(ref x)) + { + } + } + + public bool F1(int x, ref object o) + { + return default; + } + public bool F2(ref int x, ref object o) + { + return default; + } + public bool F3(ref object x, ref object o) + { + return default; + } +}"); + } [Fact] public async Task TestMethodWithReturnTypeAsync()