Instruction.Barrier

Whoops

Fix inline functions in compute stage

Fix regression

Declare SharedMemories + Only Declare Memories on Main Func

Lowecase struct

Avoid magic strings

Make function signatures readable

Change how unsized arrays are indexed

Use string builder

Fix shuffle instructions

Cleanup NumberFormater

Bunch of Subgroup I/O Vars

Will probably need further refinement

Fix point_coord type

Fix support buffer declaration

Fix point_coord
This commit is contained in:
Isaac Marovitz 2024-06-21 10:31:21 +01:00 committed by Evan Husted
parent d0e4adac36
commit becf828d0a
10 changed files with 110 additions and 77 deletions

View File

@ -9,7 +9,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public const string Tab = " ";
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
public const int AdditionalArgCount = 2;
public const int AdditionalArgCount = 1;
public StructuredFunction CurrentFunction { get; set; }

View File

@ -64,9 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined;
}
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage)
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
{
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
if (isMainFunc)
{
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
}
switch (stage)
{
case ShaderStage.Vertex:
@ -112,6 +117,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared)
{
string prefix = isShared ? "threadgroup " : string.Empty;
foreach (var memory in memories)
{
string arraySize = "";
@ -120,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
arraySize = $"[{memory.ArrayLength}]";
}
var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {memory.Name}{arraySize};");
context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};");
}
}
@ -128,23 +135,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
foreach (BufferDefinition buffer in buffers)
{
context.AppendLine($"struct Struct_{buffer.Name}");
context.AppendLine($"struct {DefaultNames.StructPrefix}_{buffer.Name}");
context.EnterScope();
foreach (StructureField field in buffer.Type.Fields)
{
if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0)
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
string arraySuffix = "";
context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];");
}
else
if (field.Type.HasFlag(AggregateType.Array))
{
string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array);
context.AppendLine($"{typeName} {field.Name};");
if (field.ArrayLength > 0)
{
arraySuffix = $"[{field.ArrayLength}]";
}
else
{
// Probably UB, but this is the approach that MVK takes
arraySuffix = "[1]";
}
}
context.AppendLine($"{typeName} {field.Name}{arraySuffix};");
}
context.LeaveScope(";");
@ -191,6 +203,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.GlobalId => "uint3",
IoVariable.VertexId => "uint",
IoVariable.VertexIndex => "uint",
IoVariable.PointCoord => "float2",
_ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false))
};
string name = ioDefinition.IoVariable switch
@ -199,6 +212,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.GlobalId => "global_id",
IoVariable.VertexId => "vertex_id",
IoVariable.VertexIndex => "vertex_index",
IoVariable.PointCoord => "point_coord",
_ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}"
};
string suffix = ioDefinition.IoVariable switch
@ -208,6 +222,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
IoVariable.VertexId => "[[vertex_id]]",
// TODO: Avoid potential redeclaration
IoVariable.VertexIndex => "[[vertex_id]]",
IoVariable.PointCoord => "[[point_coord]]",
IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]",
_ => ""
};

View File

@ -8,6 +8,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public const string IAttributePrefix = "inAttr";
public const string OAttributePrefix = "outAttr";
public const string StructPrefix = "struct";
public const string ArgumentNamePrefix = "a";
public const string UndefinedName = "0";

View File

@ -2,7 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Text;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
@ -39,11 +39,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
int arity = (int)(info.Type & InstType.ArityMask);
string args = string.Empty;
StringBuilder builder = new();
if (atomic)
if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
{
// Hell
builder.Append(GenerateLoadOrStore(context, operation, isStore: false));
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
? AggregateType.S32
: AggregateType.U32;
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
{
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}");
}
}
else
{
@ -51,16 +60,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
if (argIndex != 0)
{
args += ", ";
builder.Append(", ");
}
AggregateType dstType = GetSrcVarType(inst, argIndex);
args += GetSourceExpr(context, operation.GetSource(argIndex), dstType);
builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType));
}
}
return info.OpName + '(' + args + ')';
return $"{info.OpName}({builder})";
}
else if ((info.Type & InstType.Op) != 0)
{
@ -110,7 +119,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
switch (inst & Instruction.Mask)
{
case Instruction.Barrier:
return "|| BARRIER ||";
return "threadgroup_barrier(mem_flags::mem_threadgroup)";
case Instruction.Call:
return Call(context, operation);
case Instruction.FSIBegin:

View File

@ -13,13 +13,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
var functon = context.GetFunction(funcId.Value);
int argCount = operation.SourcesCount - 1;
string[] args = new string[argCount + CodeGenContext.AdditionalArgCount];
int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
string[] args = new string[argCount + additionalArgCount];
// Additional arguments
args[0] = "in";
args[1] = "support_buffer";
if (context.Definitions.Stage != ShaderStage.Compute)
{
args[0] = "in";
args[1] = "support_buffer";
}
else
{
args[0] = "support_buffer";
}
int argIndex = CodeGenContext.AdditionalArgCount;
int argIndex = additionalArgCount;
for (int i = 0; i < argCount; i++)
{
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i));

View File

@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle");
Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down");
Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up");
Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor");
Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle");
Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down");
Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up");
Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor");
Add(Instruction.Sine, InstType.CallUnary, "sin");
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
Add(Instruction.Store, InstType.Special);

View File

@ -47,15 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
StructureField field = buffer.Type.Fields[fieldIndex.Value];
varName = buffer.Name;
if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0)
{
// Unsized array, the buffer is indexed instead of the field
fieldName = "." + field.Name;
}
else
{
varName += "->" + field.Name;
}
varName += "->" + field.Name;
varType = field.Type;
break;

View File

@ -27,13 +27,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
IoVariable.InstanceId => ("instance_id", AggregateType.S32),
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2),
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32),
IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32),
IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool),
IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),

View File

@ -44,7 +44,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc));
context.EnterScope();
Declarations.DeclareLocals(context, function, stage);
Declarations.DeclareLocals(context, function, stage, isMainFunc);
PrintBlock(context, function.MainBlock, isMainFunc);
@ -63,15 +63,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
ShaderStage stage,
bool isMainFunc = false)
{
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount;
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc)
{
args[0] = "FragmentIn in";
args[1] = "constant Struct_support_buffer* support_buffer";
if (stage != ShaderStage.Compute)
{
args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in";
args[1] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
}
else
{
args[0] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer";
}
}
int argIndex = additionalArgCount;
@ -141,13 +148,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
foreach (var constantBuffer in context.Properties.ConstantBuffers.Values)
{
args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
args = args.Append($"constant {DefaultNames.StructPrefix}_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray();
}
foreach (var storageBuffers in context.Properties.StorageBuffers.Values)
{
// Offset the binding by 15 to avoid clashing with the constant buffers
args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
args = args.Append($"device {DefaultNames.StructPrefix}_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray();
}
foreach (var texture in context.Properties.Textures.Values)
@ -162,7 +169,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
}
}
return $"{funcKeyword} {returnType} {funcName ?? function.Name}({string.Join(", ", args)})";
var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}(";
var indent = new string(' ', funcPrefix.Length);
return $"{funcPrefix}{string.Join($", \n{indent}", args)})";
}
private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction)

View File

@ -10,25 +10,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public static bool TryFormat(int value, AggregateType dstType, out string formatted)
{
if (dstType == AggregateType.FP32)
switch (dstType)
{
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
}
else if (dstType == AggregateType.S32)
{
formatted = FormatInt(value);
}
else if (dstType == AggregateType.U32)
{
formatted = FormatUint((uint)value);
}
else if (dstType == AggregateType.Bool)
{
formatted = value != 0 ? "true" : "false";
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
case AggregateType.FP32:
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
case AggregateType.S32:
formatted = FormatInt(value);
break;
case AggregateType.U32:
formatted = FormatUint((uint)value);
break;
case AggregateType.Bool:
formatted = value != 0 ? "true" : "false";
break;
default:
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
return true;
@ -65,18 +61,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
public static string FormatInt(int value, AggregateType dstType)
{
if (dstType == AggregateType.S32)
return dstType switch
{
return FormatInt(value);
}
else if (dstType == AggregateType.U32)
{
return FormatUint((uint)value);
}
else
{
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
AggregateType.S32 => FormatInt(value),
AggregateType.U32 => FormatUint((uint)value),
_ => throw new ArgumentException($"Invalid variable type \"{dstType}\".")
};
}
public static string FormatInt(int value)