Skip to content

Commit

Permalink
use optimized network evaluation shader code
Browse files Browse the repository at this point in the history
  • Loading branch information
julienkay committed May 28, 2023
1 parent 2011cd8 commit dbd39c1
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 140 deletions.
93 changes: 19 additions & 74 deletions Editor/MobileNeRFImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Unity.Collections;
using UnityEditor;
using UnityEngine;
using static WebRequestAsyncUtility;
Expand Down Expand Up @@ -74,7 +73,7 @@ public static void ImportAssetsFromDisk() {
}
}

// ask for axis siwtch behaviour
// ask for axis switch behaviour
if (EditorUtility.DisplayDialog(SwitchAxisTitle, SwitchAxisMsg, Switch, NoSwitch)) {
SwizzleAxis = true;
} else {
Expand Down Expand Up @@ -165,7 +164,7 @@ public static void DownloadStumpAssets() {

/// <summary>
/// Some scenes require switching the y and z axis in the shader.
/// For custom scenes this tracks whether which one should be used.
/// For custom scenes this tracks, which one should be used.
/// </summary>
public static bool SwizzleAxis = false;

Expand Down Expand Up @@ -193,12 +192,6 @@ private static string GetMLPAssetPath(string objName) {
Directory.CreateDirectory(Path.GetDirectoryName(path));
return path;
}

private static string GetWeightsAssetPath(string objName, int i) {
string path = $"{GetBasePath(objName)}/MLP/weightsTex{i}.asset";
Directory.CreateDirectory(Path.GetDirectoryName(path));
return path;
}

private static string GetFeatureTextureAssetPath(string objName, int shapeNum, int featureNum) {
string path = $"{GetBasePath(objName)}/PNGs/shape{shapeNum}.pngfeat{featureNum}.png";
Expand Down Expand Up @@ -297,13 +290,12 @@ private static async Task ImportDemoSceneAsync(MNeRFScene scene) {

/// <summary>
/// Set specific import settings on OBJs/PNGs.
/// Creates Weight Textures, Materials and Shader from MLP data.
/// Creates Materials and Shader from MLP data.
/// Creates a convenient prefab for the MobileNeRF object.
/// </summary>
private static void ProcessAssets(string objName) {
Mlp mlp = GetMlp(objName);
CreateShader(objName, mlp);
CreateWeightTextures(objName, mlp);
// PNGs are configured in PNGImportProcessor.cs
ProcessOBJs(objName, mlp);
CreatePrefab(objName, mlp);
Expand Down Expand Up @@ -499,14 +491,6 @@ private static void ProcessOBJs(string objName, Mlp mlp) {
Material material = AssetDatabase.LoadAssetAtPath<Material>(materialAssetPath);
material.shader = mobileNeRFShader;

// assign weight textures
Texture2D weightsTexZero = AssetDatabase.LoadAssetAtPath<Texture2D>(GetWeightsAssetPath(objName, 0));
Texture2D weightsTexOne = AssetDatabase.LoadAssetAtPath<Texture2D>(GetWeightsAssetPath(objName, 1));
Texture2D weightsTexTwo = AssetDatabase.LoadAssetAtPath<Texture2D>(GetWeightsAssetPath(objName, 2));
material.SetTexture("weightsZero", weightsTexZero);
material.SetTexture("weightsOne", weightsTexOne);
material.SetTexture("weightsTwo", weightsTexTwo);

// assign feature textures
string feat0AssetPath = GetFeatureTextureAssetPath(objName, i, 0);
string feat1AssetPath = GetFeatureTextureAssetPath(objName, i, 1);
Expand All @@ -528,25 +512,26 @@ private static void ProcessOBJs(string objName, Mlp mlp) {
private static void CreateShader(string objName, Mlp mlp) {
int width = mlp._0Bias.Length;

StringBuilder biasListZero = toBiasList(mlp._0Bias);
StringBuilder biasListOne = toBiasList(mlp._1Bias);
StringBuilder biasListTwo = toBiasList(mlp._2Bias);

int channelsZero = mlp._0Weights.Length;
int channelsOne = mlp._0Bias.Length;
int channelsTwo = mlp._1Bias.Length;
int channelsThree = mlp._2Bias.Length;
StringBuilder biasListZero = toConstructorList(mlp._0Bias);
StringBuilder biasListOne = toConstructorList(mlp._1Bias);
StringBuilder biasListTwo = toConstructorList(mlp._2Bias);

string shaderSource = ViewDependenceNetworkShader.Template;
shaderSource = new Regex("OBJECT_NAME" ).Replace(shaderSource, $"{objName}");
shaderSource = new Regex("NUM_CHANNELS_ZERO" ).Replace(shaderSource, $"{channelsZero}");
shaderSource = new Regex("NUM_CHANNELS_ONE" ).Replace(shaderSource, $"{channelsOne}");
shaderSource = new Regex("NUM_CHANNELS_TWO" ).Replace(shaderSource, $"{channelsTwo}");
shaderSource = new Regex("NUM_CHANNELS_THREE").Replace(shaderSource, $"{channelsThree}");
shaderSource = new Regex("BIAS_LIST_ZERO" ).Replace(shaderSource, $"{biasListZero}");
shaderSource = new Regex("BIAS_LIST_ONE" ).Replace(shaderSource, $"{biasListOne}");
shaderSource = new Regex("BIAS_LIST_TWO" ).Replace(shaderSource, $"{biasListTwo}");

for (int i = 0; i < mlp._0Weights.Length; i++) {
shaderSource = new Regex($"__W0_{i}__").Replace(shaderSource, $"{toConstructorList(mlp._0Weights[i])}");
}
for (int i = 0; i < mlp._1Weights.Length; i++) {
shaderSource = new Regex($"__W1_{i}__").Replace(shaderSource, $"{toConstructorList(mlp._1Weights[i])}");
}
for (int i = 0; i < mlp._2Weights.Length; i++) {
shaderSource = new Regex($"__W2_{i}__").Replace(shaderSource, $"{toConstructorList(mlp._2Weights[i])}");
}

// hack way to flip axes depending on scene
string axisSwizzle = MNeRFSceneExtensions.ToEnum(objName).GetAxisSwizzleString();
shaderSource = new Regex("AXIS_SWIZZLE" ).Replace(shaderSource, $"{axisSwizzle}");
Expand All @@ -556,52 +541,12 @@ private static void CreateShader(string objName, Mlp mlp) {
AssetDatabase.Refresh();
}

private static void CreateWeightTextures(string objName, Mlp mlp) {
Texture2D weightsTexZero = createFloatTextureFromData(mlp._0Weights);
Texture2D weightsTexOne = createFloatTextureFromData(mlp._1Weights);
Texture2D weightsTexTwo = createFloatTextureFromData(mlp._2Weights);
AssetDatabase.CreateAsset(weightsTexZero, GetWeightsAssetPath(objName, 0));
AssetDatabase.CreateAsset(weightsTexOne, GetWeightsAssetPath(objName, 1));
AssetDatabase.CreateAsset(weightsTexTwo, GetWeightsAssetPath(objName, 2));
AssetDatabase.SaveAssets();
}

/// <summary>
/// Creates a float32 texture from an array of floats.
/// </summary>
private static Texture2D createFloatTextureFromData(double[][] weights) {
int width = weights.Length;
int height = weights[0].Length;

Texture2D texture = new Texture2D(width, height, TextureFormat.RFloat, mipChain: false, linear: true);
texture.filterMode = FilterMode.Point;
texture.wrapMode = TextureWrapMode.Clamp;
NativeArray<float> textureData = texture.GetRawTextureData<float>();
FillTexture(textureData, weights);
texture.Apply();

return texture;
}

private static void FillTexture(NativeArray<float> textureData, double[][] data) {
int width = data.Length;
int height = data[0].Length;

for (int co = 0; co < height; co++) {
for (int ci = 0; ci < width; ci++) {
int index = co * width + ci;
double weight = data[ci][co];
textureData[index] = (float)weight;
}
}
}

private static StringBuilder toBiasList(double[] biases) {
private static StringBuilder toConstructorList(double[] list) {
System.Globalization.CultureInfo culture = System.Globalization.CultureInfo.InvariantCulture;
int width = biases.Length;
int width = list.Length;
StringBuilder biasList = new StringBuilder(width * 12);
for (int i = 0; i < width; i++) {
double bias = biases[i];
double bias = list[i];
biasList.Append(bias.ToString("F7", culture));
if (i + 1 < width) {
biasList.Append(", ");
Expand Down
2 changes: 1 addition & 1 deletion Editor/MobileNeRFScene.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static string GetAxisSwizzleString(this MNeRFScene scene) {
case MNeRFScene.Custom:
// Based on user feedback for custom scenes
if (MobileNeRFImporter.SwizzleAxis) {
return "o.rayDirection.xz = -o.rayDirection.xz;" +
return "o.rayDirection.xz = -o.rayDirection.xz;" + Environment.NewLine +
"o.rayDirection.xyz = o.rayDirection.xzy;";
} else {
return "o.rayDirection.x = -o.rayDirection.x;";
Expand Down
128 changes: 63 additions & 65 deletions Editor/ShaderTemplate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ public static class ViewDependenceNetworkShader {
Properties {
tDiffuse0x (""Diffuse Texture 0"", 2D) = ""white"" {}
tDiffuse1x (""Diffuse Texture 1"", 2D) = ""white"" {}
weightsZero (""Weights Zero"", 2D) = ""white"" {}
weightsOne (""Weights One"", 2D) = ""white"" {}
weightsTwo (""Weights Two"", 2D) = ""white"" {}
}
CGINCLUDE
Expand Down Expand Up @@ -41,66 +38,68 @@ v2f vert(appdata v) {
sampler2D tDiffuse0x;
sampler2D tDiffuse1x;
sampler2D tDiffuse2x;
UNITY_DECLARE_TEX2D(weightsZero);
UNITY_DECLARE_TEX2D(weightsOne);
UNITY_DECLARE_TEX2D(weightsTwo);
half3 evaluateNetwork(fixed4 f0, fixed4 f1, fixed4 viewdir) {
half intermediate_one[NUM_CHANNELS_ONE] = { BIAS_LIST_ZERO };
int i = 0;
int j = 0;
for (j = 0; j < NUM_CHANNELS_ZERO; ++j) {
half input_value = 0.0;
if (j < 4) {
input_value =
(j == 0) ? f0.r : (
(j == 1) ? f0.g : (
(j == 2) ? f0.b : f0.a));
} else if (j < 8) {
input_value =
(j == 4) ? f1.r : (
(j == 5) ? f1.g : (
(j == 6) ? f1.b : f1.a));
} else {
input_value =
(j == 8) ? viewdir.r : (
(j == 9) ? viewdir.g : viewdir.b);
}
for (i = 0; i < NUM_CHANNELS_ONE; ++i) {
intermediate_one[i] += input_value * weightsZero.Load(int3(j, i, 0)).x;
}
}
half intermediate_two[NUM_CHANNELS_TWO] = { BIAS_LIST_ONE };
for (j = 0; j < NUM_CHANNELS_ONE; ++j) {
if (intermediate_one[j] <= 0.0) {
continue;
}
for (i = 0; i < NUM_CHANNELS_TWO; ++i) {
intermediate_two[i] += intermediate_one[j] * weightsOne.Load(int3(j, i, 0)).x;
}
}
half result[NUM_CHANNELS_THREE] = { BIAS_LIST_TWO };
for (j = 0; j < NUM_CHANNELS_TWO; ++j) {
if (intermediate_two[j] <= 0.0) {
continue;
}
for (i = 0; i < NUM_CHANNELS_THREE; ++i) {
result[i] += intermediate_two[j] * weightsTwo.Load(int3(j, i, 0)).x;
}
}
for (i = 0; i < NUM_CHANNELS_THREE; ++i) {
result[i] = 1.0 / (1.0 + exp(-result[i]));
}
return half3(result[0]*viewdir.a+(1.0-viewdir.a),
result[1]*viewdir.a+(1.0-viewdir.a),
result[2]*viewdir.a+(1.0-viewdir.a));
float4x4 intermediate_one = { BIAS_LIST_ZERO };
intermediate_one += f0.r * float4x4(__W0_0__)
+ f0.g * float4x4(__W0_1__)
+ f0.b * float4x4(__W0_2__)
+ f0.a * float4x4(__W0_3__)
+ f1.r * float4x4(__W0_4__)
+ f1.g * float4x4(__W0_5__)
+ f1.b * float4x4(__W0_6__)
+ f1.a * float4x4(__W0_7__)
+ viewdir.r * float4x4(__W0_8__)
+ viewdir.g * float4x4(__W0_9__)
+ viewdir.b * float4x4(__W0_10__);
intermediate_one[0] = max(intermediate_one[0], 0.0);
intermediate_one[1] = max(intermediate_one[1], 0.0);
intermediate_one[2] = max(intermediate_one[2], 0.0);
intermediate_one[3] = max(intermediate_one[3], 0.0);
float4x4 intermediate_two = float4x4(
BIAS_LIST_ONE
);
intermediate_two += intermediate_one[0][0] * float4x4(__W1_0__)
+ intermediate_one[0][1] * float4x4(__W1_1__)
+ intermediate_one[0][2] * float4x4(__W1_2__)
+ intermediate_one[0][3] * float4x4(__W1_3__)
+ intermediate_one[1][0] * float4x4(__W1_4__)
+ intermediate_one[1][1] * float4x4(__W1_5__)
+ intermediate_one[1][2] * float4x4(__W1_6__)
+ intermediate_one[1][3] * float4x4(__W1_7__)
+ intermediate_one[2][0] * float4x4(__W1_8__)
+ intermediate_one[2][1] * float4x4(__W1_9__)
+ intermediate_one[2][2] * float4x4(__W1_10__)
+ intermediate_one[2][3] * float4x4(__W1_11__)
+ intermediate_one[3][0] * float4x4(__W1_12__)
+ intermediate_one[3][1] * float4x4(__W1_13__)
+ intermediate_one[3][2] * float4x4(__W1_14__)
+ intermediate_one[3][3] * float4x4(__W1_15__);
intermediate_two[0] = max(intermediate_two[0], 0.0);
intermediate_two[1] = max(intermediate_two[1], 0.0);
intermediate_two[2] = max(intermediate_two[2], 0.0);
intermediate_two[3] = max(intermediate_two[3], 0.0);
float3 result = float3(
BIAS_LIST_TWO
);
result += intermediate_two[0][0] * float3(__W2_0__)
+ intermediate_two[0][1] * float3(__W2_1__)
+ intermediate_two[0][2] * float3(__W2_2__)
+ intermediate_two[0][3] * float3(__W2_3__)
+ intermediate_two[1][0] * float3(__W2_4__)
+ intermediate_two[1][1] * float3(__W2_5__)
+ intermediate_two[1][2] * float3(__W2_6__)
+ intermediate_two[1][3] * float3(__W2_7__)
+ intermediate_two[2][0] * float3(__W2_8__)
+ intermediate_two[2][1] * float3(__W2_9__)
+ intermediate_two[2][2] * float3(__W2_10__)
+ intermediate_two[2][3] * float3(__W2_11__)
+ intermediate_two[3][0] * float3(__W2_12__)
+ intermediate_two[3][1] * float3(__W2_13__)
+ intermediate_two[3][2] * float3(__W2_14__)
+ intermediate_two[3][3] * float3(__W2_15__);
result = 1.0 / (1.0 + exp(-result));
return result*viewdir.a+(1.0-viewdir.a);
}
ENDCG
Expand All @@ -120,10 +119,9 @@ fixed4 frag(v2f i) : SV_Target {
fixed4 diffuse1 = tex2D( tDiffuse1x, i.uv );
fixed4 rayDir = fixed4(normalize(i.rayDirection), 1.0);
//deal with iphone
diffuse0.a = diffuse0.a*2.0-1.0;
diffuse1.a = diffuse1.a*2.0-1.0;
rayDir.a = rayDir.a*2.0-1.0;
// normalize range to [-1, 1]
diffuse0.a = diffuse0.a * 2.0 - 1.0;
diffuse1.a = diffuse1.a * 2.0 - 1.0;
fixed4 fragColor;
fragColor.rgb = evaluateNetwork(diffuse0,diffuse1,rayDir);
Expand Down
4 changes: 4 additions & 0 deletions Editor/WebRequestAsyncUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ private static UnityWebRequest GetRequest(string url, HTTPVerb verb, string post
webRequest = UnityWebRequest.Get(url);
break;
case HTTPVerb.POST:
#if UNITY_2022_2_OR_NEWER
webRequest = UnityWebRequest.PostWwwForm(url, postData);
#else
webRequest = UnityWebRequest.Post(url, postData);
#endif
byte[] rawBody = Encoding.UTF8.GetBytes(postData);
webRequest.uploadHandler = new UploadHandlerRaw(rawBody);
webRequest.downloadHandler = new DownloadHandlerBuffer();
Expand Down

0 comments on commit dbd39c1

Please sign in to comment.