Skip to content

Commit

Permalink
Fix: PyObject array overloads precedence
Browse files Browse the repository at this point in the history
  • Loading branch information
jhonabreul committed May 24, 2024
1 parent c69fb42 commit c799b7e
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 12 deletions.
73 changes: 72 additions & 1 deletion src/embed_tests/TestMethodBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,34 @@ public string ImplicitConversionSameArgumentCount2(string symbol, decimal quanti
{
return "ImplicitConversionSameArgumentCount2 2";
}

// ----

public string VariableArgumentsMethod(params CSharpModel[] paramsParams)
{
return "VariableArgumentsMethod(CSharpModel[])";
}

public string VariableArgumentsMethod(params PyObject[] paramsParams)
{
return "VariableArgumentsMethod(PyObject[])";
}

public string ConstructorMessage { get; set; }

public OverloadsTestClass(params CSharpModel[] paramsParams)
{
ConstructorMessage = "OverloadsTestClass(CSharpModel[])";
}

public OverloadsTestClass(params PyObject[] paramsParams)
{
ConstructorMessage = "OverloadsTestClass(PyObject[])";
}

public OverloadsTestClass()
{
}
}

[TestCase("Method1('abc', namedArg1=10, namedArg2=321)", "Method1 Overload 1")]
Expand Down Expand Up @@ -907,7 +935,7 @@ public void BindsConstructorToSnakeCasedArgumentsVersion([Values] bool useCamelC
var argument2Name = useCamelCase ? "anotherArgument" : "another_argument";
var argument2Code = passOptionalArgument ? $", {argument2Name}=\"another argument value\"" : "";

var module = PyModule.FromString("CallsCorrectOverloadWithoutErrors", @$"
var module = PyModule.FromString("BindsConstructorToSnakeCasedArgumentsVersion", @$"
from clr import AddReference
AddReference(""System"")
from Python.EmbeddingTest import *
Expand All @@ -925,6 +953,49 @@ def create_instance():
Assert.AreEqual(expectedMessage, sourceException.Message);
}

[Test]
public void PyObjectArrayHasPrecedenceOverOtherTypeArrays()
{
using var _ = Py.GIL();

var module = PyModule.FromString("PyObjectArrayHasPrecedenceOverOtherTypeArrays", @$"
from clr import AddReference
AddReference(""System"")
from Python.EmbeddingTest import *
class PythonModel(TestMethodBinder.CSharpModel):
pass
def call_method():
return TestMethodBinder.OverloadsTestClass().VariableArgumentsMethod(PythonModel(), PythonModel())
");

var result = module.GetAttr("call_method").Invoke().As<string>();
Assert.AreEqual("VariableArgumentsMethod(PyObject[])", result);
}

[Test]
public void PyObjectArrayHasPrecedenceOverOtherTypeArraysInConstructors()
{
using var _ = Py.GIL();

var module = PyModule.FromString("PyObjectArrayHasPrecedenceOverOtherTypeArrays", @$"
from clr import AddReference
AddReference(""System"")
from Python.EmbeddingTest import *
class PythonModel(TestMethodBinder.CSharpModel):
pass
def get_instance():
return TestMethodBinder.OverloadsTestClass(PythonModel(), PythonModel())
");

var instance = module.GetAttr("get_instance").Invoke();
Assert.AreEqual("OverloadsTestClass(PyObject[])", instance.GetAttr("ConstructorMessage").As<string>());
}


// Used to test that we match this function with Py DateTime & Date Objects
public static int GetMonth(DateTime test)
{
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/ClassManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ void AddMember(string name, string snakeCasedName, bool isStaticReadonlyCallable
}
methodList.Add(ctor, true);
// Same constructor, but with snake-cased arguments
if (ctor.GetParameters().Any(pi => pi.Name.ToSnakeCase() != pi.Name))
if (ctor.GetParameters().Any(pi => pi.Name?.ToSnakeCase() != pi.Name))
{
methodList.Add(ctor, false);
}
Expand Down
20 changes: 10 additions & 10 deletions src/runtime/MethodBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,16 @@ internal static int ArgPrecedence(Type t, MethodInformation mi)
return -1;
}

if (t.IsArray)
{
Type e = t.GetElementType();
if (e == objectType)
{
return 2500;
}
return 100 + ArgPrecedence(e, mi);
}

TypeCode tc = Type.GetTypeCode(t);
// TODO: Clean up
switch (tc)
Expand Down Expand Up @@ -406,16 +416,6 @@ internal static int ArgPrecedence(Type t, MethodInformation mi)
return 40;
}

if (t.IsArray)
{
Type e = t.GetElementType();
if (e == objectType)
{
return 2500;
}
return 100 + ArgPrecedence(e, mi);
}

return 2000;
}

Expand Down

0 comments on commit c799b7e

Please sign in to comment.