diff --git a/runtime/convertValues.go b/runtime/convertValues.go index a71ccc151b..a234ecd290 100644 --- a/runtime/convertValues.go +++ b/runtime/convertValues.go @@ -88,11 +88,12 @@ func exportValueWithInterpreter( case interpreter.BoolValue: return cadence.NewMeteredBool(inter, bool(v)), nil case *interpreter.StringValue: + str := v.Str(inter) return cadence.NewMeteredString( inter, - common.NewCadenceStringMemoryUsage(len(v.Str)), + common.NewCadenceStringMemoryUsage(len(str)), func() string { - return v.Str + return str }, ) case interpreter.CharacterValue: diff --git a/runtime/interpreter/encode.go b/runtime/interpreter/encode.go index 997f6d5bf7..f2adb124b7 100644 --- a/runtime/interpreter/encode.go +++ b/runtime/interpreter/encode.go @@ -275,7 +275,8 @@ func (v *StringValue) Encode(e *atree.Encoder) error { if err != nil { return err } - return e.CBOR.EncodeString(v.Str) + // TODO: write normalized? no memory gauge available + return e.CBOR.EncodeString(v._str) } // Encode encodes the value as a CBOR string diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 55be367a38..f6af0149f6 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -871,7 +871,7 @@ func (interpreter *Interpreter) visitCondition(condition ast.Condition, kind ast var message string if messageExpression != nil { messageValue := interpreter.evalExpression(messageExpression) - message = messageValue.(*StringValue).Str + message = messageValue.(*StringValue).Str(interpreter) } panic(ConditionError{ @@ -2571,7 +2571,7 @@ func newFromStringFunction(ty sema.Type, parser stringValueParser) fromStringFun panic(errors.NewUnreachableError()) } inter := invocation.Interpreter - return parser(inter, argument.Str) + return parser(inter, argument.Str(inter)) }, ) return fromStringFunctionValue{ @@ -3407,6 +3407,8 @@ func dictionaryTypeFunction(invocation Invocation) Value { } func referenceTypeFunction(invocation Invocation) Value { + inter := invocation.Interpreter + entitlementValues, ok := invocation.Arguments[0].(*ArrayValue) if !ok { panic(errors.NewUnreachableError()) @@ -3423,22 +3425,23 @@ func referenceTypeFunction(invocation Invocation) Value { if entitlementsCount > 0 { authorization = NewEntitlementSetAuthorization( - invocation.Interpreter, + inter, func() []common.TypeID { entitlements := make([]common.TypeID, 0, entitlementsCount) - entitlementValues.Iterate(invocation.Interpreter, func(element Value) (resume bool) { + entitlementValues.Iterate(inter, func(element Value) (resume bool) { entitlementString, isString := element.(*StringValue) if !isString { errInIteration = true return false } - _, err := lookupEntitlement(invocation.Interpreter, entitlementString.Str) + entitlementStr := entitlementString.Str(inter) + _, err := lookupEntitlement(inter, entitlementStr) if err != nil { errInIteration = true return false } - entitlements = append(entitlements, common.TypeID(entitlementString.Str)) + entitlements = append(entitlements, common.TypeID(entitlementStr)) return true }) @@ -3454,11 +3457,11 @@ func referenceTypeFunction(invocation Invocation) Value { } return NewSomeValueNonCopying( - invocation.Interpreter, + inter, NewTypeValue( - invocation.Interpreter, + inter, NewReferenceStaticType( - invocation.Interpreter, + inter, authorization, typeValue.Type, ), @@ -3467,43 +3470,47 @@ func referenceTypeFunction(invocation Invocation) Value { } func compositeTypeFunction(invocation Invocation) Value { + inter := invocation.Interpreter + typeIDValue, ok := invocation.Arguments[0].(*StringValue) if !ok { panic(errors.NewUnreachableError()) } - typeID := typeIDValue.Str + typeID := typeIDValue.Str(inter) - composite, err := lookupComposite(invocation.Interpreter, typeID) + composite, err := lookupComposite(inter, typeID) if err != nil { return Nil } return NewSomeValueNonCopying( - invocation.Interpreter, + inter, NewTypeValue( - invocation.Interpreter, - ConvertSemaToStaticType(invocation.Interpreter, composite), + inter, + ConvertSemaToStaticType(inter, composite), ), ) } func interfaceTypeFunction(invocation Invocation) Value { + inter := invocation.Interpreter + typeIDValue, ok := invocation.Arguments[0].(*StringValue) if !ok { panic(errors.NewUnreachableError()) } - typeID := typeIDValue.Str + typeID := typeIDValue.Str(inter) - interfaceType, err := lookupInterface(invocation.Interpreter, typeID) + interfaceType, err := lookupInterface(inter, typeID) if err != nil { return Nil } return NewSomeValueNonCopying( - invocation.Interpreter, + inter, NewTypeValue( - invocation.Interpreter, - ConvertSemaToStaticType(invocation.Interpreter, interfaceType), + inter, + ConvertSemaToStaticType(inter, interfaceType), ), ) } @@ -3552,6 +3559,8 @@ func functionTypeFunction(invocation Invocation) Value { } func intersectionTypeFunction(invocation Invocation) Value { + inter := invocation.Interpreter + intersectionIDs, ok := invocation.Arguments[0].(*ArrayValue) if !ok { panic(errors.NewUnreachableError()) @@ -3566,13 +3575,13 @@ func intersectionTypeFunction(invocation Invocation) Value { semaIntersections = make([]*sema.InterfaceType, 0, count) var invalidIntersectionID bool - intersectionIDs.Iterate(invocation.Interpreter, func(typeID Value) bool { + intersectionIDs.Iterate(inter, func(typeID Value) bool { typeIDValue, ok := typeID.(*StringValue) if !ok { panic(errors.NewUnreachableError()) } - intersectedInterface, err := lookupInterface(invocation.Interpreter, typeIDValue.Str) + intersectedInterface, err := lookupInterface(inter, typeIDValue.Str(inter)) if err != nil { invalidIntersectionID = true return true @@ -3580,7 +3589,7 @@ func intersectionTypeFunction(invocation Invocation) Value { staticIntersections = append( staticIntersections, - ConvertSemaToStaticType(invocation.Interpreter, intersectedInterface).(*InterfaceStaticType), + ConvertSemaToStaticType(inter, intersectedInterface).(*InterfaceStaticType), ) semaIntersections = append(semaIntersections, intersectedInterface) @@ -3597,7 +3606,7 @@ func intersectionTypeFunction(invocation Invocation) Value { var invalidIntersectionType bool sema.CheckIntersectionType( - invocation.Interpreter, + inter, semaIntersections, func(_ func(*ast.IntersectionType) error) { invalidIntersectionType = true @@ -3611,11 +3620,11 @@ func intersectionTypeFunction(invocation Invocation) Value { } return NewSomeValueNonCopying( - invocation.Interpreter, + inter, NewTypeValue( - invocation.Interpreter, + inter, NewIntersectionStaticType( - invocation.Interpreter, + inter, staticIntersections, ), ), diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index f0aa2a80aa..93906a88d6 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -505,7 +505,7 @@ func (TypeValue) ChildStorables() []atree.Storable { // HashInput returns a byte slice containing: // - HashInputTypeType (1 byte) // - type id (n bytes) -func (v TypeValue) HashInput(interpreter *Interpreter, _ LocationRange, scratch []byte) []byte { +func (v TypeValue) HashInput(_ *Interpreter, _ LocationRange, scratch []byte) []byte { typeID := v.Type.ID() length := 1 + len(typeID) @@ -1014,15 +1014,20 @@ type StringValue struct { // which is initialized lazily and reused/reset in functions // that are based on grapheme clusters graphemes *uniseg.Graphemes - Str string + // Deprecated: Use Str(). + // _str is the raw string value. + // At construction, it is not normalized yet. + // On use (via Str()), it gets normalized. + _str string // length is the cached length of the string, based on grapheme clusters. // a negative value indicates the length has not been initialized, see Length() - length int + length int + normalized bool } func NewUnmeteredStringValue(str string) *StringValue { return &StringValue{ - Str: str, + _str: str, // a negative value indicates the length has not been initialized, see Length() length: -1, } @@ -1047,9 +1052,9 @@ var _ ValueIndexableValue = &StringValue{} var _ MemberAccessibleValue = &StringValue{} var _ IterableValue = &StringValue{} -func (v *StringValue) prepareGraphemes() { +func (v *StringValue) prepareGraphemes(memoryGauge common.MemoryGauge) { if v.graphemes == nil { - v.graphemes = uniseg.NewGraphemes(v.Str) + v.graphemes = uniseg.NewGraphemes(v.Str(memoryGauge)) } else { v.graphemes.Reset() } @@ -1074,7 +1079,7 @@ func (*StringValue) IsImportable(_ *Interpreter) bool { } func (v *StringValue) String() string { - return format.String(v.Str) + return format.String(v.Str(nil)) } func (v *StringValue) RecursiveString(_ SeenReferences) string { @@ -1082,17 +1087,17 @@ func (v *StringValue) RecursiveString(_ SeenReferences) string { } func (v *StringValue) MeteredString(memoryGauge common.MemoryGauge, _ SeenReferences) string { - l := format.FormattedStringLength(v.Str) + l := format.FormattedStringLength(v.Str(memoryGauge)) common.UseMemory(memoryGauge, common.NewRawStringMemoryUsage(l)) return v.String() } -func (v *StringValue) Equal(_ *Interpreter, _ LocationRange, other Value) bool { +func (v *StringValue) Equal(interpreter *Interpreter, _ LocationRange, other Value) bool { otherString, ok := other.(*StringValue) if !ok { return false } - return v.NormalForm() == otherString.NormalForm() + return v.Str(interpreter) == otherString.Str(interpreter) } func (v *StringValue) Less(interpreter *Interpreter, other ComparableValue, locationRange LocationRange) BoolValue { @@ -1105,8 +1110,7 @@ func (v *StringValue) Less(interpreter *Interpreter, other ComparableValue, loca LocationRange: locationRange, }) } - - return AsBoolValue(v.NormalForm() < otherString.NormalForm()) + return AsBoolValue(v.Str(interpreter) < otherString.Str(interpreter)) } func (v *StringValue) LessEqual(interpreter *Interpreter, other ComparableValue, locationRange LocationRange) BoolValue { @@ -1120,7 +1124,7 @@ func (v *StringValue) LessEqual(interpreter *Interpreter, other ComparableValue, }) } - return AsBoolValue(v.NormalForm() <= otherString.NormalForm()) + return AsBoolValue(v.Str(interpreter) <= otherString.Str(interpreter)) } func (v *StringValue) Greater(interpreter *Interpreter, other ComparableValue, locationRange LocationRange) BoolValue { @@ -1134,7 +1138,7 @@ func (v *StringValue) Greater(interpreter *Interpreter, other ComparableValue, l }) } - return AsBoolValue(v.NormalForm() > otherString.NormalForm()) + return AsBoolValue(v.Str(interpreter) > otherString.Str(interpreter)) } func (v *StringValue) GreaterEqual(interpreter *Interpreter, other ComparableValue, locationRange LocationRange) BoolValue { @@ -1148,14 +1152,14 @@ func (v *StringValue) GreaterEqual(interpreter *Interpreter, other ComparableVal }) } - return AsBoolValue(v.NormalForm() >= otherString.NormalForm()) + return AsBoolValue(v.Str(interpreter) >= otherString.Str(interpreter)) } // HashInput returns a byte slice containing: // - HashInputTypeString (1 byte) // - string value (n bytes) -func (v *StringValue) HashInput(_ *Interpreter, _ LocationRange, scratch []byte) []byte { - length := 1 + len(v.Str) +func (v *StringValue) HashInput(inter *Interpreter, _ LocationRange, scratch []byte) []byte { + length := 1 + len(v.Str(inter)) var buffer []byte if length <= len(scratch) { buffer = scratch[:length] @@ -1164,18 +1168,26 @@ func (v *StringValue) HashInput(_ *Interpreter, _ LocationRange, scratch []byte) } buffer[0] = byte(HashInputTypeString) - copy(buffer[1:], v.Str) + copy(buffer[1:], v.Str(inter)) return buffer } -func (v *StringValue) NormalForm() string { - return norm.NFC.String(v.Str) +func (v *StringValue) Str(memoryGauge common.MemoryGauge) string { + if !v.normalized { + common.UseMemory( + memoryGauge, + common.NewRawStringMemoryUsage(len(v._str)), + ) + v._str = norm.NFC.String(v._str) + v.normalized = true + } + return v._str } func (v *StringValue) Concat(interpreter *Interpreter, other *StringValue, locationRange LocationRange) Value { - firstLength := len(v.Str) - secondLength := len(other.Str) + firstLength := len(v.Str(interpreter)) + secondLength := len(other.Str(interpreter)) newLength := safeAdd(firstLength, secondLength, locationRange) @@ -1187,8 +1199,8 @@ func (v *StringValue) Concat(interpreter *Interpreter, other *StringValue, locat func() string { var sb strings.Builder - sb.WriteString(v.Str) - sb.WriteString(other.Str) + sb.WriteString(v.Str(interpreter)) + sb.WriteString(other.Str(interpreter)) return sb.String() }, @@ -1197,12 +1209,12 @@ func (v *StringValue) Concat(interpreter *Interpreter, other *StringValue, locat var EmptyString = NewUnmeteredStringValue("") -func (v *StringValue) Slice(from IntValue, to IntValue, locationRange LocationRange) Value { +func (v *StringValue) Slice(interpreter *Interpreter, from IntValue, to IntValue, locationRange LocationRange) Value { fromIndex := from.ToInt(locationRange) toIndex := to.ToInt(locationRange) - length := v.Length() + length := v.Length(interpreter) if fromIndex < 0 || fromIndex > length || toIndex < 0 || toIndex > length { panic(StringSliceIndicesError{ @@ -1225,7 +1237,7 @@ func (v *StringValue) Slice(from IntValue, to IntValue, locationRange LocationRa return EmptyString } - v.prepareGraphemes() + v.prepareGraphemes(interpreter) j := 0 @@ -1241,11 +1253,11 @@ func (v *StringValue) Slice(from IntValue, to IntValue, locationRange LocationRa // NOTE: string slicing in Go does not copy, // see https://stackoverflow.com/questions/52395730/does-slice-of-string-perform-copy-of-underlying-data - return NewUnmeteredStringValue(v.Str[start:end]) + return NewUnmeteredStringValue(v.Str(interpreter)[start:end]) } -func (v *StringValue) checkBounds(index int, locationRange LocationRange) { - length := v.Length() +func (v *StringValue) checkBounds(memoryGauge common.MemoryGauge, index int, locationRange LocationRange) { + length := v.Length(memoryGauge) if index < 0 || index >= length { panic(StringIndexOutOfBoundsError{ @@ -1258,9 +1270,9 @@ func (v *StringValue) checkBounds(index int, locationRange LocationRange) { func (v *StringValue) GetKey(interpreter *Interpreter, locationRange LocationRange, key Value) Value { index := key.(NumberValue).ToInt(locationRange) - v.checkBounds(index, locationRange) + v.checkBounds(interpreter, index, locationRange) - v.prepareGraphemes() + v.prepareGraphemes(interpreter) for j := 0; j <= index; j++ { v.graphemes.Next() @@ -1288,14 +1300,14 @@ func (*StringValue) RemoveKey(_ *Interpreter, _ LocationRange, _ Value) Value { panic(errors.NewUnreachableError()) } -func (v *StringValue) GetMember(interpreter *Interpreter, locationRange LocationRange, name string) Value { +func (v *StringValue) GetMember(interpreter *Interpreter, _ LocationRange, name string) Value { switch name { case sema.StringTypeLengthFieldName: - length := v.Length() + length := v.Length(interpreter) return NewIntValueFromInt64(interpreter, int64(length)) case sema.StringTypeUtf8FieldName: - return ByteSliceToByteArrayValue(interpreter, []byte(v.Str)) + return ByteSliceToByteArrayValue(interpreter, []byte(v.Str(interpreter))) case sema.StringTypeConcatFunctionName: return NewHostFunctionValue( @@ -1303,6 +1315,8 @@ func (v *StringValue) GetMember(interpreter *Interpreter, locationRange Location sema.StringTypeConcatFunctionType, func(invocation Invocation) Value { interpreter := invocation.Interpreter + locationRange := invocation.LocationRange + otherArray, ok := invocation.Arguments[0].(*StringValue) if !ok { panic(errors.NewUnreachableError()) @@ -1316,6 +1330,9 @@ func (v *StringValue) GetMember(interpreter *Interpreter, locationRange Location interpreter, sema.StringTypeSliceFunctionType, func(invocation Invocation) Value { + interpreter := invocation.Interpreter + locationRange := invocation.LocationRange + from, ok := invocation.Arguments[0].(IntValue) if !ok { panic(errors.NewUnreachableError()) @@ -1326,7 +1343,7 @@ func (v *StringValue) GetMember(interpreter *Interpreter, locationRange Location panic(errors.NewUnreachableError()) } - return v.Slice(from, to, invocation.LocationRange) + return v.Slice(interpreter, from, to, locationRange) }, ) @@ -1366,10 +1383,10 @@ func (*StringValue) SetMember(_ *Interpreter, _ LocationRange, _ string, _ Value } // Length returns the number of characters (grapheme clusters) -func (v *StringValue) Length() int { +func (v *StringValue) Length(memoryGauge common.MemoryGauge) int { if v.length < 0 { var length int - v.prepareGraphemes() + v.prepareGraphemes(memoryGauge) for v.graphemes.Next() { length++ } @@ -1380,12 +1397,14 @@ func (v *StringValue) Length() int { func (v *StringValue) ToLower(interpreter *Interpreter) *StringValue { + str := v.Str(interpreter) + // Over-estimate resulting string length, // as an uppercase character may be converted to several lower-case characters, e.g İ => [i, ̇] // see https://stackoverflow.com/questions/28683805/is-there-a-unicode-string-which-gets-longer-when-converted-to-lowercase var lengthEstimate int - for _, r := range v.Str { + for _, r := range str { if r < unicode.MaxASCII { lengthEstimate += 1 } else { @@ -1399,7 +1418,7 @@ func (v *StringValue) ToLower(interpreter *Interpreter) *StringValue { interpreter, memoryUsage, func() string { - return strings.ToLower(v.Str) + return strings.ToLower(str) }, ) } @@ -1430,8 +1449,8 @@ func (v *StringValue) Transfer( return v } -func (v *StringValue) Clone(_ *Interpreter) Value { - return NewUnmeteredStringValue(v.Str) +func (v *StringValue) Clone(inter *Interpreter) Value { + return NewUnmeteredStringValue(v.Str(inter)) } func (*StringValue) DeepRemove(_ *Interpreter) { @@ -1439,7 +1458,8 @@ func (*StringValue) DeepRemove(_ *Interpreter) { } func (v *StringValue) ByteSize() uint32 { - return cborTagSize + getBytesCBORSize([]byte(v.Str)) + // TODO: write normalized? no memory gauge available + return cborTagSize + getBytesCBORSize([]byte(v._str)) } func (v *StringValue) StoredValue(_ atree.SlabStorage) (atree.Value, error) { @@ -1455,7 +1475,7 @@ var ByteArrayStaticType = ConvertSemaArrayTypeToStaticArrayType(nil, sema.ByteAr // DecodeHex hex-decodes this string and returns an array of UInt8 values func (v *StringValue) DecodeHex(interpreter *Interpreter, locationRange LocationRange) *ArrayValue { - bs, err := hex.DecodeString(v.Str) + bs, err := hex.DecodeString(v.Str(interpreter)) if err != nil { if err, ok := err.(hex.InvalidByteError); ok { panic(InvalidHexByteError{ @@ -1507,9 +1527,9 @@ func (v *StringValue) ConformsToStaticType( return true } -func (v *StringValue) Iterator(_ *Interpreter) ValueIterator { +func (v *StringValue) Iterator(interpreter *Interpreter) ValueIterator { return StringValueIterator{ - graphemes: uniseg.NewGraphemes(v.Str), + graphemes: uniseg.NewGraphemes(v.Str(interpreter)), } } @@ -20603,17 +20623,18 @@ func AddressFromBytes(invocation Invocation) Value { } func AddressFromString(invocation Invocation) Value { + inter := invocation.Interpreter + argument, ok := invocation.Arguments[0].(*StringValue) if !ok { panic(errors.NewUnreachableError()) } - addr, err := common.HexToAddressAssertPrefix(argument.Str) + addr, err := common.HexToAddressAssertPrefix(argument.Str(inter)) if err != nil { return Nil } - inter := invocation.Interpreter return NewSomeValueNonCopying(inter, NewAddressValue(inter, addr)) } @@ -20785,9 +20806,11 @@ func convertPath(interpreter *Interpreter, domain common.PathDomain, value Value return Nil } + str := stringValue.Str(interpreter) + _, err := sema.CheckPathLiteral( domain.Identifier(), - stringValue.Str, + str, ReturnEmptyRange, ReturnEmptyRange, ) @@ -20800,7 +20823,7 @@ func convertPath(interpreter *Interpreter, domain common.PathDomain, value Value NewPathValue( interpreter, domain, - stringValue.Str, + str, ), ) } diff --git a/runtime/interpreter/value_deployedcontract.go b/runtime/interpreter/value_deployedcontract.go index 745c9fc3bf..02bddbf977 100644 --- a/runtime/interpreter/value_deployedcontract.go +++ b/runtime/interpreter/value_deployedcontract.go @@ -57,6 +57,10 @@ func NewDeployedContractValue( ) } +var MetaTypeArrayStaticType = &VariableSizedStaticType{ + Type: PrimitiveStaticTypeMetaType, +} + func newPublicTypesFunctionValue(inter *Interpreter, addressValue AddressValue, name *StringValue) FunctionValue { // public types only need to be computed once per contract var publicTypes *ArrayValue @@ -67,12 +71,12 @@ func newPublicTypesFunctionValue(inter *Interpreter, addressValue AddressValue, sema.DeployedContractTypePublicTypesFunctionType, func(inv Invocation) Value { if publicTypes == nil { - innerInter := inv.Interpreter - contractLocation := common.NewAddressLocation(innerInter, address, name.Str) + inter := inv.Interpreter + qualifiedIdentifier := name.Str(inter) + contractLocation := common.NewAddressLocation(inter, address, qualifiedIdentifier) // we're only looking at the contract as a whole, so no need to construct a nested path - qualifiedIdent := name.Str - typeID := common.NewTypeIDFromQualifiedName(innerInter, contractLocation, qualifiedIdent) - compositeType, err := innerInter.GetCompositeType(contractLocation, qualifiedIdent, typeID) + typeID := common.NewTypeIDFromQualifiedName(inter, contractLocation, qualifiedIdentifier) + compositeType, err := inter.GetCompositeType(contractLocation, qualifiedIdentifier, typeID) if err != nil { panic(err) } @@ -85,14 +89,14 @@ func newPublicTypesFunctionValue(inter *Interpreter, addressValue AddressValue, if pair == nil { return nil } - typeValue := NewTypeValue(innerInter, ConvertSemaToStaticType(innerInter, pair.Value)) + typeValue := NewTypeValue(inter, ConvertSemaToStaticType(inter, pair.Value)) pair = pair.Next() return typeValue } publicTypes = NewArrayValueWithIterator( - innerInter, - NewVariableSizedStaticType(innerInter, PrimitiveStaticTypeMetaType), + inter, + MetaTypeArrayStaticType, common.Address{}, uint64(nestedTypes.Len()), yieldNext, diff --git a/runtime/runtime_memory_metering_test.go b/runtime/runtime_memory_metering_test.go index 9ec4886c76..be8beb0280 100644 --- a/runtime/runtime_memory_metering_test.go +++ b/runtime/runtime_memory_metering_test.go @@ -97,7 +97,7 @@ func TestRuntimeInterpreterAddressLocationMetering(t *testing.T) { assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindAddressLocation)) assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindElaboration)) - assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindRawString)) + assert.Equal(t, uint64(21), meter.getMemory(common.MemoryKindRawString)) assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindCadenceVoidValue)) }) } @@ -783,10 +783,8 @@ func TestRuntimeLogFunctionStringConversionMetering(t *testing.T) { getSigningAccounts: func() ([]Address, error) { return []Address{{42}}, nil }, - storage: newTestLedger(nil, nil), - meterMemory: func(usage common.MemoryUsage) error { - return meter.MeterMemory(usage) - }, + storage: newTestLedger(nil, nil), + meterMemory: meter.MeterMemory, getAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { return accountCode, nil }, @@ -818,7 +816,7 @@ func TestRuntimeLogFunctionStringConversionMetering(t *testing.T) { diffOfActualLen := nonEmptyStrActualLen - emptyStrActualLen diffOfMeteredAmount := nonEmptyStrMeteredAmount - emptyStrMeteredAmount - assert.Equal(t, diffOfActualLen, diffOfMeteredAmount) + assert.Equal(t, diffOfActualLen*2, diffOfMeteredAmount) } func TestRuntimeStorageCommitsMetering(t *testing.T) { diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 57a0071593..246ef51fff 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -952,7 +952,7 @@ func newAccountInboxPublishFunction( nil, ) - storageMapKey := interpreter.StringStorageMapKey(nameValue.Str) + storageMapKey := interpreter.StringStorageMapKey(nameValue.Str(inter)) inter.WriteStored( provider, @@ -984,7 +984,7 @@ func newAccountInboxUnpublishFunction( inter := invocation.Interpreter locationRange := invocation.LocationRange - storageMapKey := interpreter.StringStorageMapKey(nameValue.Str) + storageMapKey := interpreter.StringStorageMapKey(nameValue.Str(inter)) readValue := inter.ReadStored(provider, InboxStorageDomain, storageMapKey) if readValue == nil { @@ -1065,7 +1065,7 @@ func newAccountInboxClaimFunction( providerAddress := providerValue.ToAddress() - storageMapKey := interpreter.StringStorageMapKey(nameValue.Str) + storageMapKey := interpreter.StringStorageMapKey(nameValue.Str(inter)) readValue := inter.ReadStored(providerAddress, InboxStorageDomain, storageMapKey) if readValue == nil { @@ -1220,12 +1220,14 @@ func newAccountContractsGetFunction( gauge, functionType, func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + nameValue, ok := invocation.Arguments[0].(*interpreter.StringValue) if !ok { panic(errors.NewUnreachableError()) } - name := nameValue.Str - location := common.NewAddressLocation(invocation.Interpreter, address, name) + + location := common.NewAddressLocation(inter, address, nameValue.Str(inter)) var code []byte var err error @@ -1238,13 +1240,13 @@ func newAccountContractsGetFunction( if len(code) > 0 { return interpreter.NewSomeValueNonCopying( - invocation.Interpreter, + inter, interpreter.NewDeployedContractValue( - invocation.Interpreter, + inter, addressValue, nameValue, interpreter.ByteSliceToByteArrayValue( - invocation.Interpreter, + inter, code, ), ), @@ -1277,8 +1279,9 @@ func newAccountContractsBorrowFunction( if !ok { panic(errors.NewUnreachableError()) } - name := nameValue.Str - location := common.NewAddressLocation(invocation.Interpreter, address, name) + + name := nameValue.Str(inter) + location := common.NewAddressLocation(inter, address, name) typeParameterPair := invocation.TypeParameterTypes.Oldest() if typeParameterPair == nil { @@ -1387,6 +1390,7 @@ func changeAccountContracts( addressValue interpreter.AddressValue, isUpdate bool, ) interpreter.Value { + inter := invocation.Interpreter locationRange := invocation.LocationRange @@ -1405,14 +1409,14 @@ func changeAccountContracts( constructorArguments := invocation.Arguments[requiredArgumentCount:] constructorArgumentTypes := invocation.ArgumentTypes[requiredArgumentCount:] - code, err := interpreter.ByteArrayValueToByteSlice(invocation.Interpreter, newCodeValue, locationRange) + code, err := interpreter.ByteArrayValueToByteSlice(inter, newCodeValue, locationRange) if err != nil { panic(errors.NewDefaultUserError("add requires the second argument to be an array")) } // Get the existing code - contractName := nameValue.Str + contractName := nameValue.Str(inter) if contractName == "" { panic(errors.NewDefaultUserError( @@ -1422,7 +1426,7 @@ func changeAccountContracts( } address := addressValue.ToAddress() - location := common.NewAddressLocation(invocation.Interpreter, address, contractName) + location := common.NewAddressLocation(inter, address, contractName) existingCode, err := handler.GetAccountContractCode(location) if err != nil { @@ -1560,7 +1564,7 @@ func changeAccountContracts( handleContractUpdateError(err) oldProgram, err := parser.ParseProgram( - invocation.Interpreter.SharedState.Config.MemoryGauge, + inter.SharedState.Config.MemoryGauge, oldCode, parser.Config{ IgnoreLeadingIdentifierEnabled: true, @@ -1581,8 +1585,6 @@ func changeAccountContracts( handleContractUpdateError(err) } - inter := invocation.Interpreter - err = updateAccountContractCode( handler, location, @@ -1937,14 +1939,15 @@ func newAccountContractsRemoveFunction( gauge, sema.Account_ContractsTypeRemoveFunctionType, func(invocation interpreter.Invocation) interpreter.Value { - inter := invocation.Interpreter + nameValue, ok := invocation.Arguments[0].(*interpreter.StringValue) if !ok { panic(errors.NewUnreachableError()) } - name := nameValue.Str - location := common.NewAddressLocation(invocation.Interpreter, address, name) + + name := nameValue.Str(inter) + location := common.NewAddressLocation(inter, address, name) // Get the current code diff --git a/runtime/stdlib/assert.go b/runtime/stdlib/assert.go index e730e771ba..d5891d5311 100644 --- a/runtime/stdlib/assert.go +++ b/runtime/stdlib/assert.go @@ -58,6 +58,8 @@ var AssertFunction = NewStandardLibraryFunction( assertFunctionType, assertFunctionDocString, func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + result, ok := invocation.Arguments[0].(interpreter.BoolValue) if !ok { panic(errors.NewUnreachableError()) @@ -70,7 +72,7 @@ var AssertFunction = NewStandardLibraryFunction( if !ok { panic(errors.NewUnreachableError()) } - message = messageValue.Str + message = messageValue.Str(inter) } panic(AssertionError{ Message: message, diff --git a/runtime/stdlib/hashalgorithm.go b/runtime/stdlib/hashalgorithm.go index 65019f84e3..5ecbd613d5 100644 --- a/runtime/stdlib/hashalgorithm.go +++ b/runtime/stdlib/hashalgorithm.go @@ -144,7 +144,7 @@ func hash( var tag string if tagValue != nil { - tag = tagValue.Str + tag = tagValue.Str(inter) } hashAlgorithm := NewHashAlgorithmFromValue(inter, locationRange, hashAlgorithmValue) diff --git a/runtime/stdlib/panic.go b/runtime/stdlib/panic.go index d9f6296780..931bc8f966 100644 --- a/runtime/stdlib/panic.go +++ b/runtime/stdlib/panic.go @@ -60,15 +60,18 @@ var PanicFunction = NewStandardLibraryFunction( panicFunctionType, panicFunctionDocString, func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + locationRange := invocation.LocationRange + messageValue, ok := invocation.Arguments[0].(*interpreter.StringValue) if !ok { panic(errors.NewUnreachableError()) } - message := messageValue.Str + message := messageValue.Str(inter) panic(PanicError{ Message: message, - LocationRange: invocation.LocationRange, + LocationRange: locationRange, }) }, ) diff --git a/runtime/stdlib/publickey.go b/runtime/stdlib/publickey.go index 5b02f4250d..932034eee3 100644 --- a/runtime/stdlib/publickey.go +++ b/runtime/stdlib/publickey.go @@ -274,7 +274,7 @@ func newPublicKeyVerifySignatureFunction( panic(errors.NewUnexpectedError("failed to get signed data. %w", err)) } - domainSeparationTag := domainSeparationTagValue.Str + domainSeparationTag := domainSeparationTagValue.Str(inter) hashAlgorithm := NewHashAlgorithmFromValue(inter, locationRange, hashAlgorithmValue) diff --git a/runtime/stdlib/test_contract.go b/runtime/stdlib/test_contract.go index 488f1f24f4..943366d749 100644 --- a/runtime/stdlib/test_contract.go +++ b/runtime/stdlib/test_contract.go @@ -77,6 +77,8 @@ var testTypeAssertFunctionType = &sema.FunctionType{ var testTypeAssertFunction = interpreter.NewUnmeteredHostFunctionValue( testTypeAssertFunctionType, func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + condition, ok := invocation.Arguments[0].(interpreter.BoolValue) if !ok { panic(errors.NewUnreachableError()) @@ -88,7 +90,7 @@ var testTypeAssertFunction = interpreter.NewUnmeteredHostFunctionValue( if !ok { panic(errors.NewUnreachableError()) } - message = messageValue.Str + message = messageValue.Str(inter) } if !condition { @@ -194,13 +196,15 @@ var testTypeFailFunctionType = &sema.FunctionType{ var testTypeFailFunction = interpreter.NewUnmeteredHostFunctionValue( testTypeFailFunctionType, func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + var message string if len(invocation.Arguments) > 0 { messageValue, ok := invocation.Arguments[0].(*interpreter.StringValue) if !ok { panic(errors.NewUnreachableError()) } - message = messageValue.Str + message = messageValue.Str(inter) } panic(AssertionError{ @@ -351,12 +355,14 @@ func newTestTypeReadFileFunction(testFramework TestFramework) *interpreter.HostF return interpreter.NewUnmeteredHostFunctionValue( testTypeReadFileFunctionType, func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + pathString, ok := invocation.Arguments[0].(*interpreter.StringValue) if !ok { panic(errors.NewUnreachableError()) } - content, err := testFramework.ReadFile(pathString.Str) + content, err := testFramework.ReadFile(pathString.Str(inter)) if err != nil { panic(err) } @@ -903,7 +909,10 @@ func newTestTypeExpectFailureFunction( defer inter.RecoverErrors(func(internalErr error) { if !failedAsExpected { panic(internalErr) - } else if !strings.Contains(internalErr.Error(), errorMessage.Str) { + } else if !strings.Contains( + internalErr.Error(), + errorMessage.Str(inter), + ) { msg := fmt.Sprintf( "Expected error message to include: %s.", errorMessage, diff --git a/runtime/stdlib/test_emulatorbackend.go b/runtime/stdlib/test_emulatorbackend.go index aca8706e56..46af4859f9 100644 --- a/runtime/stdlib/test_emulatorbackend.go +++ b/runtime/stdlib/test_emulatorbackend.go @@ -269,7 +269,11 @@ func (t *testEmulatorBackendType) newExecuteScriptFunction( panic(errors.NewUnexpectedErrorFromCause(err)) } - result := blockchain.RunScript(inter, script.Str, args) + result := blockchain.RunScript( + inter, + script.Str(inter), + args, + ) return newScriptResult(inter, result.Value, result) }, @@ -421,7 +425,7 @@ func (t *testEmulatorBackendType) newAddTransactionFunction( err = blockchain.AddTransaction( inter, - code.Str, + code.Str(inter), authorizers, signerAccounts, args, @@ -531,8 +535,8 @@ func (t *testEmulatorBackendType) newDeployContractFunction( err = blockchain.DeployContract( inter, - name.Str, - code.Str, + name.Str(inter), + code.Str(inter), account, args, ) @@ -587,7 +591,9 @@ func (t *testEmulatorBackendType) newUseConfigFunction( panic(errors.NewUnreachableError()) } - mapping[location.Str] = common.Address(address) + locationStr := location.Str(inter) + + mapping[locationStr] = common.Address(address) return true }) @@ -780,12 +786,14 @@ func (t *testEmulatorBackendType) newCreateSnapshotFunction( return interpreter.NewUnmeteredHostFunctionValue( t.createSnapshotFunctionType, func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + name, ok := invocation.Arguments[0].(*interpreter.StringValue) if !ok { panic(errors.NewUnreachableError()) } - err := blockchain.CreateSnapshot(name.Str) + err := blockchain.CreateSnapshot(name.Str(inter)) return newErrorValue(invocation.Interpreter, err) }, ) @@ -806,12 +814,14 @@ func (t *testEmulatorBackendType) newLoadSnapshotFunction( return interpreter.NewUnmeteredHostFunctionValue( t.loadSnapshotFunctionType, func(invocation interpreter.Invocation) interpreter.Value { + inter := invocation.Interpreter + name, ok := invocation.Arguments[0].(*interpreter.StringValue) if !ok { panic(errors.NewUnreachableError()) } - err := blockchain.LoadSnapshot(name.Str) + err := blockchain.LoadSnapshot(name.Str(inter)) return newErrorValue(invocation.Interpreter, err) }, ) diff --git a/runtime/tests/interpreter/container_mutation_test.go b/runtime/tests/interpreter/container_mutation_test.go index 904ff2ece3..9499699a19 100644 --- a/runtime/tests/interpreter/container_mutation_test.go +++ b/runtime/tests/interpreter/container_mutation_test.go @@ -539,7 +539,11 @@ func TestInterpretDictionaryMutation(t *testing.T) { interpreter.NewUnmeteredStringValue("foo"), ) assert.True(t, present) - assert.Equal(t, interpreter.NewUnmeteredStringValue("baz"), val) + AssertValuesEqual(t, + inter, + interpreter.NewUnmeteredStringValue("baz"), + val, + ) }) t.Run("simple dictionary invalid", func(t *testing.T) { @@ -618,7 +622,11 @@ func TestInterpretDictionaryMutation(t *testing.T) { interpreter.NewUnmeteredStringValue("foo"), ) assert.True(t, present) - assert.Equal(t, interpreter.NewUnmeteredStringValue("baz"), val) + AssertValuesEqual(t, + inter, + interpreter.NewUnmeteredStringValue("baz"), + val, + ) }) t.Run("dictionary insert invalid", func(t *testing.T) { @@ -1058,7 +1066,11 @@ func TestInterpretInnerContainerMutationWhileIteratingOuter(t *testing.T) { interpreter.NewUnmeteredStringValue("name"), ) assert.True(t, present) - assert.Equal(t, interpreter.NewUnmeteredStringValue("hello"), val) + AssertValuesEqual(t, + inter, + interpreter.NewUnmeteredStringValue("hello"), + val, + ) }) t.Run("dictionary", func(t *testing.T) { @@ -1093,6 +1105,10 @@ func TestInterpretInnerContainerMutationWhileIteratingOuter(t *testing.T) { interpreter.NewUnmeteredStringValue("name"), ) assert.True(t, present) - assert.Equal(t, interpreter.NewUnmeteredStringValue("foo"), val) + AssertValuesEqual(t, + inter, + interpreter.NewUnmeteredStringValue("foo"), + val, + ) }) } diff --git a/runtime/tests/interpreter/interface_test.go b/runtime/tests/interpreter/interface_test.go index 59ffb437f2..010bc540e5 100644 --- a/runtime/tests/interpreter/interface_test.go +++ b/runtime/tests/interpreter/interface_test.go @@ -578,7 +578,8 @@ func TestInterpretInterfaceFunctionConditionsInheritance(t *testing.T) { logFunctionType, "", func(invocation interpreter.Invocation) interpreter.Value { - msg := invocation.Arguments[0].(*interpreter.StringValue).Str + inter := invocation.Interpreter + msg := invocation.Arguments[0].(*interpreter.StringValue).Str(inter) logs = append(logs, msg) return interpreter.Void }, @@ -686,7 +687,8 @@ func TestInterpretInterfaceFunctionConditionsInheritance(t *testing.T) { logFunctionType, "", func(invocation interpreter.Invocation) interpreter.Value { - msg := invocation.Arguments[0].(*interpreter.StringValue).Str + inter := invocation.Interpreter + msg := invocation.Arguments[0].(*interpreter.StringValue).Str(inter) logs = append(logs, msg) return interpreter.Void }, @@ -794,7 +796,8 @@ func TestInterpretInterfaceFunctionConditionsInheritance(t *testing.T) { logFunctionType, "", func(invocation interpreter.Invocation) interpreter.Value { - msg := invocation.Arguments[0].(*interpreter.StringValue).Str + inter := invocation.Interpreter + msg := invocation.Arguments[0].(*interpreter.StringValue).Str(inter) logs = append(logs, msg) return interpreter.Void }, diff --git a/runtime/tests/interpreter/interpreter_test.go b/runtime/tests/interpreter/interpreter_test.go index 6f2bb2e5d3..4147a7c796 100644 --- a/runtime/tests/interpreter/interpreter_test.go +++ b/runtime/tests/interpreter/interpreter_test.go @@ -4491,39 +4491,45 @@ func TestInterpretDictionaryIndexingType(t *testing.T) { let f = x[Type<@TestResource>()] `) - assert.Equal(t, + AssertValuesEqual(t, + inter, interpreter.NewUnmeteredSomeValueNonCopying( interpreter.NewUnmeteredStringValue("a"), ), inter.Globals.Get("a").GetValue(), ) - assert.Equal(t, + AssertValuesEqual(t, + inter, interpreter.NewUnmeteredSomeValueNonCopying( interpreter.NewUnmeteredStringValue("b"), ), inter.Globals.Get("b").GetValue(), ) - assert.Equal(t, + AssertValuesEqual(t, + inter, interpreter.NewUnmeteredSomeValueNonCopying( interpreter.NewUnmeteredStringValue("c"), ), inter.Globals.Get("c").GetValue(), ) - assert.Equal(t, + AssertValuesEqual(t, + inter, interpreter.Nil, inter.Globals.Get("d").GetValue(), ) // types need to match exactly, subtypes won't cut it - assert.Equal(t, + AssertValuesEqual(t, + inter, interpreter.Nil, inter.Globals.Get("e").GetValue(), ) - assert.Equal(t, + AssertValuesEqual(t, + inter, interpreter.NewUnmeteredSomeValueNonCopying( interpreter.NewUnmeteredStringValue("f"), ), diff --git a/runtime/tests/interpreter/memory_metering_test.go b/runtime/tests/interpreter/memory_metering_test.go index 69bef281d7..0e5e7783c3 100644 --- a/runtime/tests/interpreter/memory_metering_test.go +++ b/runtime/tests/interpreter/memory_metering_test.go @@ -644,7 +644,7 @@ func TestInterpretCompositeMetering(t *testing.T) { require.NoError(t, err) assert.Equal(t, uint64(6), meter.getMemory(common.MemoryKindStringValue)) - assert.Equal(t, uint64(66), meter.getMemory(common.MemoryKindRawString)) + assert.Equal(t, uint64(76), meter.getMemory(common.MemoryKindRawString)) assert.Equal(t, uint64(4), meter.getMemory(common.MemoryKindCompositeValueBase)) assert.Equal(t, uint64(5), meter.getMemory(common.MemoryKindAtreeMapDataSlab)) assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindAtreeMapMetaDataSlab)) @@ -767,7 +767,7 @@ func TestInterpretCompositeFieldMetering(t *testing.T) { _, err := inter.Invoke("main") require.NoError(t, err) - assert.Equal(t, uint64(16), meter.getMemory(common.MemoryKindRawString)) + assert.Equal(t, uint64(20), meter.getMemory(common.MemoryKindRawString)) assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindCompositeValueBase)) assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindAtreeMapElementOverhead)) assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindAtreeMapDataSlab)) @@ -798,7 +798,7 @@ func TestInterpretCompositeFieldMetering(t *testing.T) { _, err := inter.Invoke("main") require.NoError(t, err) - assert.Equal(t, uint64(34), meter.getMemory(common.MemoryKindRawString)) + assert.Equal(t, uint64(44), meter.getMemory(common.MemoryKindRawString)) assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindAtreeMapDataSlab)) assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindAtreeMapElementOverhead)) assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindAtreeMapMetaDataSlab)) @@ -7826,7 +7826,7 @@ func TestInterpreterStringLocationMetering(t *testing.T) { testLocationStringCount := meter.getMemory(common.MemoryKindRawString) // raw string location is "test" + locationIDs - assert.Equal(t, uint64(5), testLocationStringCount-emptyLocationStringCount) + assert.Equal(t, uint64(13), testLocationStringCount-emptyLocationStringCount) assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindCompositeStaticType))