diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt index d1a3caaf..acd69d7f 100644 --- a/ptx/src/test/spirv_run/ntid.spvtxt +++ b/ptx/src/test/spirv_run/ntid.spvtxt @@ -12,16 +12,16 @@ OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %v4uint = OpTypeVector %uint 4 -%_ptr_UniformConstant_v4uint = OpTypePointer UniformConstant %v4uint -%gl_WorkGroupSize = OpVariable %_ptr_UniformConstant_v4uint UniformConstant %ulong = OpTypeInt 64 0 - %33 = OpTypeFunction %void %ulong %ulong + %v3ulong = OpTypeVector %ulong 3 +%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong +%gl_WorkGroupSize = OpVariable %_ptr_Input_v3ulong Input + %32 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %33 + %1 = OpFunction %void None %32 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong %25 = OpLabel @@ -41,9 +41,11 @@ %23 = OpConvertUToPtr %_ptr_Generic_uint %14 %13 = OpLoad %uint %23 OpStore %6 %13 - %16 = OpLoad %v4uint %gl_WorkGroupSize - %22 = OpCompositeExtract %uint %16 0 - %15 = OpCopyObject %uint %22 + %37 = OpLoad %v3ulong %gl_WorkGroupSize + %22 = OpCompositeExtract %ulong %37 0 + %38 = OpBitcast %ulong %22 + %16 = OpUConvert %uint %38 + %15 = OpCopyObject %uint %16 OpStore %7 %15 %18 = OpLoad %uint %6 %19 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt index c53ad516..bad44f42 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt @@ -12,16 +12,16 @@ OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %v4uint = OpTypeVector %uint 4 -%_ptr_Input_v4uint = OpTypePointer Input %v4uint -%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %ulong = OpTypeInt 64 0 + %v3ulong = OpTypeVector %ulong 3 +%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong +%gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %56 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar + %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint - %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %1 = OpFunction %void None %56 @@ -57,25 +57,27 @@ %18 = OpCopyObject %ulong %19 %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 OpStore %11 %27 - %29 = OpLoad %v4uint %gl_LocalInvocationID - %42 = OpCompositeExtract %uint %29 0 - %28 = OpCopyObject %uint %42 + %61 = OpLoad %v3ulong %gl_LocalInvocationID + %42 = OpCompositeExtract %ulong %61 0 + %62 = OpBitcast %ulong %42 + %29 = OpUConvert %uint %62 + %28 = OpCopyObject %uint %29 OpStore %6 %28 %31 = OpLoad %uint %6 - %61 = OpBitcast %uint %31 - %30 = OpUConvert %ulong %61 + %63 = OpBitcast %uint %31 + %30 = OpUConvert %ulong %63 OpStore %7 %30 %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 %34 = OpLoad %ulong %7 - %62 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 - %63 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %62 %34 - %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %63 + %64 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 + %65 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %64 %34 + %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %65 OpStore %10 %32 %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11 %37 = OpLoad %ulong %7 - %64 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 - %65 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %64 %37 - %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %65 + %66 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 + %67 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %66 %37 + %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %67 OpStore %11 %35 %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10 %45 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt index 5ba889c3..cc99aa0f 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt @@ -12,16 +12,16 @@ OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %v4uint = OpTypeVector %uint 4 -%_ptr_Input_v4uint = OpTypePointer Input %v4uint -%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %ulong = OpTypeInt 64 0 + %v3ulong = OpTypeVector %ulong 3 +%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong +%gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %64 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar + %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint - %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %1 = OpFunction %void None %64 @@ -61,25 +61,27 @@ %26 = OpCopyObject %ulong %27 %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 OpStore %18 %35 - %37 = OpLoad %v4uint %gl_LocalInvocationID - %50 = OpCompositeExtract %uint %37 0 - %36 = OpCopyObject %uint %50 + %69 = OpLoad %v3ulong %gl_LocalInvocationID + %50 = OpCompositeExtract %ulong %69 0 + %70 = OpBitcast %ulong %50 + %37 = OpUConvert %uint %70 + %36 = OpCopyObject %uint %37 OpStore %10 %36 %39 = OpLoad %uint %10 - %69 = OpBitcast %uint %39 - %38 = OpUConvert %ulong %69 + %71 = OpBitcast %uint %39 + %38 = OpUConvert %ulong %71 OpStore %11 %38 %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15 %42 = OpLoad %ulong %11 - %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 - %71 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %70 %42 - %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %71 + %72 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 + %73 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %72 %42 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %73 OpStore %16 %40 %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18 %45 = OpLoad %ulong %11 - %72 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 - %73 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %72 %45 - %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %73 + %74 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %75 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %74 %45 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %75 OpStore %19 %43 %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16 %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt index 3c215d46..32f2afba 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt @@ -12,16 +12,16 @@ OpEntryPoint Kernel %1 "stateful_ld_st_ntid_sub" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %v4uint = OpTypeVector %uint 4 -%_ptr_Input_v4uint = OpTypePointer Input %v4uint -%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %ulong = OpTypeInt 64 0 + %v3ulong = OpTypeVector %ulong 3 +%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong +%gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %72 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar + %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint - %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong %ulong_0 = OpConstant %ulong 0 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong @@ -63,43 +63,45 @@ %26 = OpCopyObject %ulong %27 %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 OpStore %18 %37 - %39 = OpLoad %v4uint %gl_LocalInvocationID - %52 = OpCompositeExtract %uint %39 0 - %38 = OpCopyObject %uint %52 + %77 = OpLoad %v3ulong %gl_LocalInvocationID + %52 = OpCompositeExtract %ulong %77 0 + %78 = OpBitcast %ulong %52 + %39 = OpUConvert %uint %78 + %38 = OpCopyObject %uint %39 OpStore %10 %38 %41 = OpLoad %uint %10 - %77 = OpBitcast %uint %41 - %40 = OpUConvert %ulong %77 + %79 = OpBitcast %uint %41 + %40 = OpUConvert %ulong %79 OpStore %11 %40 %42 = OpLoad %ulong %11 %59 = OpCopyObject %ulong %42 %28 = OpSNegate %ulong %59 %44 = OpLoad %_ptr_CrossWorkgroup_uchar %15 - %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 - %79 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %78 %28 - %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %79 + %80 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %81 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %80 %28 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %81 OpStore %16 %43 %45 = OpLoad %ulong %11 %60 = OpCopyObject %ulong %45 %29 = OpSNegate %ulong %60 %47 = OpLoad %_ptr_CrossWorkgroup_uchar %18 - %80 = OpBitcast %_ptr_CrossWorkgroup_uchar %47 - %81 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %80 %29 - %46 = OpBitcast %_ptr_CrossWorkgroup_uchar %81 + %82 = OpBitcast %_ptr_CrossWorkgroup_uchar %47 + %83 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %82 %29 + %46 = OpBitcast %_ptr_CrossWorkgroup_uchar %83 OpStore %19 %46 %49 = OpLoad %_ptr_CrossWorkgroup_uchar %16 %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %49 - %83 = OpBitcast %_ptr_CrossWorkgroup_uchar %61 - %84 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %83 %ulong_0 - %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %84 + %85 = OpBitcast %_ptr_CrossWorkgroup_uchar %61 + %86 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %85 %ulong_0 + %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %86 %48 = OpLoad %ulong %54 OpStore %12 %48 %50 = OpLoad %_ptr_CrossWorkgroup_uchar %19 %51 = OpLoad %ulong %12 %62 = OpBitcast %_ptr_CrossWorkgroup_ulong %50 - %85 = OpBitcast %_ptr_CrossWorkgroup_uchar %62 - %86 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %85 %ulong_0_0 - %56 = OpBitcast %_ptr_CrossWorkgroup_ulong %86 + %87 = OpBitcast %_ptr_CrossWorkgroup_uchar %62 + %88 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %87 %ulong_0_0 + %56 = OpBitcast %_ptr_CrossWorkgroup_ulong %88 OpStore %56 %51 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 20578ebe..dac31a5f 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1030,7 +1030,7 @@ fn emit_builtins( map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver, ) { - for (reg, id) in id_defs.special_registers.iter() { + for (reg, id) in id_defs.special_registers.builtins() { let result_type = map.get_or_add( builder, SpirvType::Pointer( @@ -1038,9 +1038,9 @@ fn emit_builtins( spirv::StorageClass::Input, ), ); - builder.variable(result_type, Some(*id), spirv::StorageClass::Input, None); + builder.variable(result_type, Some(id), spirv::StorageClass::Input, None); builder.decorate( - *id, + id, spirv::Decoration::BuiltIn, &[dr::Operand::BuiltIn(reg.get_builtin())], ); @@ -1086,11 +1086,7 @@ fn emit_function_header<'a>( .iter() .filter_map(|(k, t)| t.as_ref().map(|_| *k)) .collect::>(); - let mut interface = defined_globals - .special_registers - .iter() - .map(|(_, id)| *id) - .collect::>(); + let mut interface = defined_globals.special_registers.interface(); for ast::Variable { name, .. } in synthetic_globals { interface.push(*name); } @@ -1326,6 +1322,7 @@ fn to_ssa<'input, 'b>( &f_args, &mut spirv_decl, )?; + let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1343,6 +1340,87 @@ fn to_ssa<'input, 'b>( }) } +fn fix_builtins( + typed_statements: Vec, + numeric_id_defs: &mut NumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(typed_statements.len()); + for s in typed_statements { + match s { + Statement::LoadVar( + mut + details + @ + LoadVarDetails { + member_index: Some((_, Some(_))), + .. + }, + ) => { + let index = details.member_index.unwrap().0; + if index == 3 { + result.push(Statement::Constant(ConstantDefinition { + dst: details.arg.dst, + typ: ast::ScalarType::U32, + value: ast::ImmediateValue::U64(0), + })); + } else { + let src_type = match numeric_id_defs.special_registers.get(details.arg.src) { + Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg), + None => None, + }; + let (sreg_src, scalar_typ, vector_width) = match src_type { + Some(x) => x, + None => { + result.push(Statement::LoadVar(details)); + continue; + } + }; + let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone())); + let real_dst = details.arg.dst; + details.arg.dst = temp_id; + result.push(Statement::LoadVar(LoadVarDetails { + arg: Arg2 { + src: sreg_src, + dst: temp_id, + }, + typ: ast::Type::Scalar(scalar_typ), + member_index: Some((index, Some(vector_width))), + })); + result.push(Statement::Conversion(ImplicitConversion { + src: temp_id, + dst: real_dst, + from: ast::Type::Scalar(scalar_typ), + to: ast::Type::Scalar(ast::ScalarType::U32), + kind: ConversionKind::Default, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, + })); + } + } + s => result.push(s), + } + } + Ok(result) +} + +fn get_sreg_id_scalar_type( + numeric_id_defs: &mut NumericIdResolver, + sreg: PtxSpecialRegister, +) -> Option<(spirv::Word, ast::ScalarType, u8)> { + match sreg.normalized_sreg_and_type() { + Some((normalized_sreg, typ, vec_width)) => Some(( + numeric_id_defs.special_registers.replace( + numeric_id_defs.current_id, + sreg, + normalized_sreg, + ), + typ, + vec_width, + )), + None => None, + } +} + fn extract_globals<'input, 'b>( sorted_statements: Vec, ptx_impl_imports: &mut HashMap, @@ -2058,7 +2136,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }; Some(( idx, - if self.id_def.special_registers.contains_key(&symbol) { + if self.id_def.special_registers.get(symbol).is_some() { Some(vector_width) } else { None @@ -4599,9 +4677,13 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { Tid, + Tid64, Ntid, + Ntid64, Ctaid, + Ctaid64, Nctaid, + Nctaid64, } impl PtxSpecialRegister { @@ -4618,18 +4700,110 @@ impl PtxSpecialRegister { fn get_type(self) -> ast::Type { match self { PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Tid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Ntid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Ctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4), + PtxSpecialRegister::Nctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), } } fn get_builtin(self) -> spirv::BuiltIn { match self { - PtxSpecialRegister::Tid => spirv::BuiltIn::LocalInvocationId, - PtxSpecialRegister::Ntid => spirv::BuiltIn::WorkgroupSize, - PtxSpecialRegister::Ctaid => spirv::BuiltIn::WorkgroupId, - PtxSpecialRegister::Nctaid => spirv::BuiltIn::NumWorkgroups, + PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { + spirv::BuiltIn::LocalInvocationId + } + PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => spirv::BuiltIn::WorkgroupSize, + PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => spirv::BuiltIn::WorkgroupId, + PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { + spirv::BuiltIn::NumWorkgroups + } + } + } + + fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> { + match self { + PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)), + PtxSpecialRegister::Ntid => Some((PtxSpecialRegister::Ntid64, ast::ScalarType::U64, 3)), + PtxSpecialRegister::Ctaid => { + Some((PtxSpecialRegister::Ctaid64, ast::ScalarType::U64, 3)) + } + PtxSpecialRegister::Nctaid => { + Some((PtxSpecialRegister::Nctaid64, ast::ScalarType::U64, 3)) + } + PtxSpecialRegister::Tid64 + | PtxSpecialRegister::Ntid64 + | PtxSpecialRegister::Ctaid64 + | PtxSpecialRegister::Nctaid64 => None, + } + } +} + +struct SpecialRegistersMap { + reg_to_id: HashMap, + id_to_reg: HashMap, +} + +impl SpecialRegistersMap { + fn new() -> Self { + SpecialRegistersMap { + reg_to_id: HashMap::new(), + id_to_reg: HashMap::new(), + } + } + + fn builtins<'a>(&'a self) -> impl Iterator + 'a { + self.reg_to_id.iter().map(|(reg, id)| (*reg, *id)) + } + + fn interface(&self) -> Vec { + self.id_to_reg.iter().map(|(id, _)| *id).collect::>() + } + + fn get(&self, id: spirv::Word) -> Option { + self.id_to_reg.get(&id).copied() + } + + fn get_or_add(&mut self, current_id: &mut spirv::Word, reg: PtxSpecialRegister) -> spirv::Word { + match self.reg_to_id.entry(reg) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = *current_id; + *current_id += 1; + e.insert(numeric_id); + self.id_to_reg.insert(numeric_id, reg); + numeric_id + } + } + } + + fn replace( + &mut self, + current_id: &mut spirv::Word, + old: PtxSpecialRegister, + new: PtxSpecialRegister, + ) -> spirv::Word { + match self.reg_to_id.entry(old) { + hash_map::Entry::Occupied(e) => { + let id = e.remove(); + self.reg_to_id.insert(new, id); + id + } + hash_map::Entry::Vacant(e) => { + drop(e); + match self.reg_to_id.entry(new) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = *current_id; + *current_id += 1; + e.insert(numeric_id); + self.id_to_reg.insert(numeric_id, new); + numeric_id + } + } + } } } } @@ -4638,7 +4812,7 @@ struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, variables_type_check: HashMap>, - special_registers: HashMap, + special_registers: SpecialRegistersMap, fns: HashMap, } @@ -4653,7 +4827,7 @@ impl<'a> GlobalStringIdResolver<'a> { current_id: start_id, variables: HashMap::new(), variables_type_check: HashMap::new(), - special_registers: HashMap::new(), + special_registers: SpecialRegistersMap::new(), fns: HashMap::new(), } } @@ -4768,7 +4942,7 @@ struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, global_type_check: &'b HashMap>, - special_registers: &'b mut HashMap, + special_registers: &'b mut SpecialRegistersMap, variables: Vec, spirv::Word>>, type_check: HashMap>, } @@ -4779,11 +4953,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { current_id: self.current_id, global_type_check: self.global_type_check, type_check: self.type_check, - special_registers: self - .special_registers - .iter() - .map(|(reg, id)| (*id, *reg)) - .collect(), + special_registers: self.special_registers, } } @@ -4807,15 +4977,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { None => { let sreg = PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?; - match self.special_registers.entry(sreg) { - hash_map::Entry::Occupied(e) => Ok(*e.get()), - hash_map::Entry::Vacant(e) => { - let numeric_id = *self.current_id; - *self.current_id += 1; - e.insert(numeric_id); - Ok(numeric_id) - } - } + Ok(self.special_registers.get_or_add(self.current_id, sreg)) } } } @@ -4858,7 +5020,7 @@ struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, global_type_check: &'b HashMap>, type_check: HashMap>, - special_registers: HashMap, + special_registers: &'b mut SpecialRegistersMap, } impl<'b> NumericIdResolver<'b> { @@ -4870,7 +5032,7 @@ impl<'b> NumericIdResolver<'b> { match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), - None => match self.special_registers.get(&id) { + None => match self.special_registers.get(id) { Some(x) => Ok((x.get_type(), true)), None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()), @@ -5006,7 +5168,19 @@ impl ExpandedStatement { offset_src: constant_src, }) } - Statement::RepackVector(_) => todo!(), + Statement::RepackVector(repack) => { + let packed = f(repack.packed, !repack.is_extract); + let unpacked = repack + .unpacked + .iter() + .map(|id| f(*id, repack.is_extract)) + .collect(); + Statement::RepackVector(RepackVectorDetails { + packed, + unpacked, + ..repack + }) + } } } }