Skip to content

Commit

Permalink
Assignment to indexed array locations.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Aug 2, 2024
1 parent e8b1441 commit f582537
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 50 deletions.
12 changes: 8 additions & 4 deletions compiler/ast/lrvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ type LRValue struct {

func (lrv LRValue) String() string {
offset := lrv.baseInfo.Offset + lrv.valueType.Offset
return fmt.Sprintf("%s[%d-%d]@%s{%d}%s",
return fmt.Sprintf("%s[%d-%d]@%s{%d}%s/%v",
lrv.valueType, offset, offset+lrv.valueType.Bits,
lrv.baseInfo.Name, lrv.baseInfo.Scope, lrv.baseInfo.ContainerType)
lrv.baseInfo.Name, lrv.baseInfo.Scope, lrv.baseInfo.ContainerType,
lrv.baseInfo.ContainerType.Bits)
}

// BaseType returns the base type of the LRValue.
Expand Down Expand Up @@ -92,20 +93,23 @@ func (lrv *LRValue) Indirect() *LRValue {
ret.value.PtrInfo = nil

if lrv.baseInfo.ContainerType.Type == types.TStruct {
// Set value to undefined so RValue() can regenerate it.
ret.value.Type = types.Undefined

// Lookup struct field.
ret.structField = nil
for _, f := range lrv.baseValue.Type.Struct {
for idx, f := range lrv.baseValue.Type.Struct {
if f.Type.Offset == lrv.baseInfo.Offset {
ret.structField = &f
ret.structField = &lrv.baseValue.Type.Struct[idx]
break
}
}
if ret.structField == nil {
panic("LRValue.Indirect: could not find struct field")
}
} else {
ret.value.Type = *lrv.value.Type.ElementType
ret.value = lrv.baseValue
}

return &ret
Expand Down
113 changes: 71 additions & 42 deletions compiler/ast/ssagen.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package ast
import (
"fmt"
"os"
"slices"

"github.com/markkurossi/mpc/compiler/ssa"
"github.com/markkurossi/mpc/compiler/utils"
Expand Down Expand Up @@ -317,61 +318,80 @@ func (ast *Assign) SSA(block *ssa.Block, ctx *Codegen,
return nil, nil, ctx.Errorf(ast,
"a non-name %s on left side of :=", lv)
}
switch arr := lv.Expr.(type) {
case *VariableRef:
lrv, _, _, err := ctx.LookupVar(block, gen, block.Bindings, arr)
if err != nil {
return nil, nil, ctx.Error(arr, err.Error())
}
valueType := lrv.ValueType()
if valueType.Type == types.TPtr {
valueType = *valueType.ElementType
}

if !valueType.Type.Array() {
return nil, nil, ctx.Errorf(ast,
"setting elements of non-array %s (%s)",
arr, lrv.ValueType())
}
arraySize := valueType.ArraySize
elementSize := valueType.ElementType.Bits
var err error
var v []ssa.Value
var indices []arrayIndex
var lrv *LRValue
idx := lv

block, val, err := lv.Index.SSA(block, ctx, gen)
for lrv == nil {
block, v, err = idx.Index.SSA(block, ctx, gen)
if err != nil {
return nil, nil, err
}
if len(val) != 1 {
return nil, nil, ctx.Errorf(lv.Index, "invalid index")
if len(v) != 1 {
return nil, nil, ctx.Errorf(idx.Index, "invalid index")
}
index, err := val[0].ConstInt()
index, err := v[0].ConstInt()
if err != nil {
return nil, nil, ctx.Errorf(lv.Index, "%s", err)
return nil, nil, ctx.Error(idx.Index, err.Error())
}
indices = append(indices, arrayIndex{
i: index,
ast: idx.Index,
})
switch i := idx.Expr.(type) {
case *Index:
idx = i

case *VariableRef:
lrv, _, _, err = ctx.LookupVar(block, gen,
block.Bindings, i)
if err != nil {
return nil, nil, err
}

default:
return nil, nil, ctx.Errorf(idx.Expr,
"invalid operation: cannot index %v (%T)",
idx.Expr, idx.Expr)

// Convert index to bit range.
if index >= arraySize {
return nil, nil, ctx.Errorf(lv.Index,
"invalid array index %d (out of bounds for %d-element array)",
index, arraySize)
}
basePtrInfo := lrv.BasePtrInfo()
from := int64(index*elementSize + basePtrInfo.Offset)
to := int64(from + int64(elementSize))
}
slices.Reverse(indices)

fromConst := gen.Constant(from, types.Undefined)
toConst := gen.Constant(to, types.Undefined)
lrv = lrv.Indirect()
t := lrv.ValueType()
var offset types.Size

lValue := lrv.LValue()
block.AddInstr(ssa.NewAmovInstr(rv, lrv.BaseValue(),
fromConst, toConst, lValue))
err = basePtrInfo.Bindings.Set(lValue, nil)
if err != nil {
return nil, nil, ctx.Error(ast, err.Error())
for _, index := range indices {
if !t.Type.Array() {
return nil, nil, ctx.Errorf(index.ast,
"setting elements of non-array %s (%s)", lv.Expr, t)
}
if index.i >= t.ArraySize {
return nil, nil, ctx.Errorf(index.ast,
"invalid array index %d (out of bounds for %d-element array)",
index.i, t.ArraySize)
}
offset += index.i * t.ElementType.Bits
t = *t.ElementType
}

default:
return nil, nil, ctx.Errorf(ast,
"array expression not supported: %T", arr)
if !ssa.CanAssign(t, rv) {
return nil, nil, ctx.Errorf(lvalue,
"cannot assign %v to variable of type %v", rv.Type, t)
}

val := gen.AnonVal(lrv.ValueType())
fromConst := gen.Constant(int64(offset), types.Undefined)
toConst := gen.Constant(int64(offset+t.Bits), types.Undefined)
block.AddInstr(ssa.NewAmovInstr(rv, lrv.RValue(), fromConst,
toConst, val))

err = lrv.Set(val)
if err != nil {
return nil, nil, ctx.Error(lvalue, err.Error())
}

case *Unary:
Expand Down Expand Up @@ -444,6 +464,15 @@ func (ast *Assign) SSA(block *ssa.Block, ctx *Codegen,
return block, values, nil
}

type arrayIndex struct {
i types.Size
ast AST
}

func (i arrayIndex) String() string {
return fmt.Sprintf("%d", i.i)
}

// SSA implements the compiler.ast.AST.SSA for if statements.
func (ast *If) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
*ssa.Block, []ssa.Value, error) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/crypto/aes/circuit.mpcl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// -*- go -*-
//
// Copyright (c) 2020-2021 Markku Rossi
// Copyright (c) 2020-2024 Markku Rossi
//
// All rights reserved.
//
Expand All @@ -24,7 +24,7 @@ func Block128(key [16]byte, data [16]byte) [16]byte {
c := block128(k, d)
var cipher [16]byte
for i := len(cipher) - 1; i >= 0; i-- {
cipher[i] = c & 0xff
cipher[i] = byte(c & 0xff)
c >>= 8
}
return cipher
Expand Down
4 changes: 2 additions & 2 deletions pkg/crypto/cipher/gcm/gcm.mpcl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func uint128ToByte(x uint128) [16]byte {
var r [16]byte

for i := 0; i < 16; i++ {
r[15-i] = x & 0xff
r[15-i] = byte(x & 0xff)
x >>= 8
}
return r
Expand All @@ -203,7 +203,7 @@ func incr(counter [aes.BlockSize]byte) [aes.BlockSize]byte {
}
c++
for i := 0; i < 4; i++ {
counter[15-i] = c & 0xff
counter[15-i] = byte(c & 0xff)
c >>= 8
}
return counter
Expand Down

0 comments on commit f582537

Please sign in to comment.