Skip to content

Commit

Permalink
updates for ndb context caching everything: Object.new/changed logic,…
Browse files Browse the repository at this point in the history
… tests

for #1149, 18aa302
  • Loading branch information
snarfed committed Jan 14, 2025
1 parent a756edb commit 32a60c9
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 61 deletions.
6 changes: 3 additions & 3 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,12 @@ def report_error(msg, *, exception=False, **kwargs):


def cache_policy(key):
"""In memory ndb cache, only DID docs right now.
"""In memory ndb cache. Cache everything!
https://github.com/snarfed/bridgy-fed/issues/1149#issuecomment-2261383697
Avoid caching much more due to this bug where unstored in-memory
modifications get returned by later gets:
Keep an eye on this in case we start seeing problems due to this ndb bug
where unstored in-memory modifications get returned by later gets:
https://github.com/googleapis/python-ndb/issues/888
Args:
Expand Down
3 changes: 1 addition & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@
for logger in ('atproto_firehose', 'lexrpc', 'oauth_dropins.webutil.webmention'):
logging.getLogger(logger).setLevel(logging.DEBUG)

logging.getLogger('google.cloud').propagate = True

# for debugging ndb. also needs NDB_DEBUG env var, set in *.yaml.
# https://github.com/googleapis/python-ndb/blob/c55ec62b5153787404488b046c4bf6ffa02fee64/google/cloud/ndb/utils.py#L78-L81
# logging.getLogger('google.cloud').propagate = True
# logging.getLogger('google.cloud.ndb').setLevel(logging.DEBUG)
# logging.getLogger('google.cloud.ndb._cache').setLevel(logging.DEBUG)
# logging.getLogger('google.cloud.ndb.global_cache').setLevel(logging.DEBUG)
Expand Down
8 changes: 7 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,13 +1095,19 @@ def get_or_create(cls, id, authed_as=None, **props):
obj = Object(id=id)
obj.new = True

obj.changed = None
for field in 'new', 'changed':
val = props.pop(field, None)
if val is not None:
setattr(obj, field, val)

if set(props.keys()) & set(('as2', 'bsky', 'mf2', 'raw')):
obj.clear()
obj.populate(**{
k: v for k, v in props.items()
if v and not isinstance(getattr(Object, k), ndb.ComputedProperty)
})
if not obj.new:
if not obj.new and obj.changed is None:
obj.changed = obj.activity_changed(orig_as1)

obj.put()
Expand Down
21 changes: 9 additions & 12 deletions protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,12 +907,9 @@ def receive(from_cls, obj, authed_as=None, internal=False, received_at=None):
error(f'Actor {actor} is opted out or blocked', status=204)

# write Object to datastore
orig = obj
obj = Object.get_or_create(id, authed_as=actor, **orig.to_dict())
if orig.new is not None:
obj.new = orig.new
if orig.changed is not None:
obj.changed = orig.changed
orig_props = obj.to_dict()
obj = Object.get_or_create(id, new=obj.new, changed=obj.changed,
authed_as=actor, **orig_props)

# if this is an object, ie not an activity, wrap it in a create or update
obj = from_cls.handle_bare_object(obj, authed_as=authed_as)
Expand Down Expand Up @@ -1558,23 +1555,23 @@ def load(cls, id, remote=None, local=True, raise_=True, **kwargs):
"""
assert id
assert local or remote is not False
# logger.debug(f'Loading Object {id} local={local} remote={remote}')
logger.debug(f'Loading Object {id} local={local} remote={remote}')

obj = orig_as1 = None
if local and not obj:
if local:
obj = Object.get_by_id(id)
if not obj:
# logger.debug(f' not in datastore')
logger.debug(f' {id} not in datastore')
pass
elif obj.as1 or obj.raw or obj.deleted:
# logger.debug(' got from datastore')
logger.debug(f' {id} got from datastore')
obj.new = False

if remote is False:
return obj
elif remote is None and obj:
if obj.updated < util.as_utc(util.now() - OBJECT_REFRESH_AGE):
# logger.debug(f' last updated {obj.updated}, refreshing')
logger.debug(f' last updated {obj.updated}, refreshing')
pass
else:
return obj
Expand All @@ -1586,7 +1583,7 @@ def load(cls, id, remote=None, local=True, raise_=True, **kwargs):
else:
obj = Object(id=id)
if local:
# logger.debug(' not in datastore')
logger.debug(f' {id} not in datastore')
obj.new = True
obj.changed = False

Expand Down
4 changes: 3 additions & 1 deletion tests/test_activitypub.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,7 @@ def test_inbox_follow_use_instead_strip_www(self, mock_head, mock_get, mock_post
mock_get.side_effect = [
# source actor
self.as2_resp(ACTOR),
self.as2_resp(ACTOR),
# target user
test_web.ACTOR_HTML_RESP,
# target post webmention discovery
Expand Down Expand Up @@ -1333,6 +1334,7 @@ def test_inbox_follow_web_brid_gy_subdomain(self, mock_head, mock_get, mock_post
mock_get.side_effect = [
# source actor
self.as2_resp(ACTOR),
self.as2_resp(ACTOR),
# target user
test_web.ACTOR_HTML_RESP,
# target post webmention discovery
Expand Down Expand Up @@ -1780,11 +1782,11 @@ def test_inbox_no_webmention_endpoint(self, mock_head, mock_get, mock_post):

def test_inbox_id_already_seen(self, mock_head, mock_get, mock_post):
mock_get.side_effect = [
self.as2_resp(ACTOR),
self.as2_resp(ACTOR),
HTML,
]


obj_key = Object(id=FOLLOW_WRAPPED['id'], as2={}).put()

got = self.post('/user.com/inbox', json=FOLLOW_WRAPPED)
Expand Down
2 changes: 2 additions & 0 deletions tests/test_atproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dns.resolver import NXDOMAIN
import google.cloud.dns.client
from google.cloud.dns.zone import ManagedZone
from google.cloud import ndb
from google.cloud.tasks_v2.types import Task
from granary import bluesky
from granary.tests.test_bluesky import (
Expand Down Expand Up @@ -1781,6 +1782,7 @@ def test_send_update_wrong_repo(self):
_, _, rkey = arroba.util.parse_at_uri(orig.copies[0].uri)
orig.copies[0].uri = orig.copies[0].uri.replace('did:plc:user', 'did:plc:eve')
orig.put()
ndb.context.get_context().cache.clear()

update = Object(id='fake:update', source_protocol='fake', our_as1={
'objectType': 'activity',
Expand Down
2 changes: 2 additions & 0 deletions tests/test_atproto_firehose.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from carbox import read_car, write_car
from carbox.car import Block
import dag_cbor
from google.cloud import ndb
from google.cloud.tasks_v2.types import Task
from granary.tests.test_bluesky import (
ACTOR_PROFILE_BSKY,
Expand Down Expand Up @@ -484,6 +485,7 @@ def _now(tz=None):
+ STORE_CURSOR_FREQ - timedelta(seconds=1))
FakeWebsocketClient.setup_receive(op)
self.subscribe()
ndb.context.get_context().cache.clear()
self.assertEqual(444, self.cursor.key.get().cursor)

# now it's been long enough
Expand Down
5 changes: 3 additions & 2 deletions tests/test_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,9 @@ def test_callback_stored_followee_with_our_as1(self, mock_get, mock_post):
source_protocol='activitypub')

mock_get.side_effect = (
requests_response(''), # alice.com h-card
requests_response(''), # indieauth alice.com fetch for user json
requests_response(''), # alice.com h-card
self.as2_resp(FOLLOWEE),
requests_response(''), # indieauth alice.com fetch for user json
)
mock_post.side_effect = (
requests_response('me=https://alice.com'),
Expand Down
22 changes: 1 addition & 21 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_load_multi(self):
bob = Fake(id='bob.com', obj_key=Object(id='bob').key)
bob.put()

user = self.user.key.get()
user = self.user.key.get(use_cache=False)
self.assertFalse(hasattr(user, '_obj'))
self.assertFalse(hasattr(alice, '_obj'))
self.assertIsNone(bob._obj)
Expand Down Expand Up @@ -580,26 +580,6 @@ def test_target_hashable(self):
# just check that these don't crash
assert isinstance(id(target), int)

def test_ndb_in_memory_cache_off(self):
"""It has a weird bug that we want to avoid.
https://github.com/googleapis/python-ndb/issues/888
"""
from google.cloud.ndb import Model, StringProperty
class Foo(Model):
a = StringProperty()

f = Foo(id='x', a='asdf')
f.put()
# print(id(f))

f.a = 'qwert'

got = Foo.get_by_id('x')
# print(got)
# print(id(got))
self.assertEqual('asdf', got.a)

def test_get_or_create(self):
def check(obj1, obj2):
self.assert_entities_equal(obj1, obj2, ignore=['expire', 'updated'])
Expand Down
18 changes: 6 additions & 12 deletions tests/test_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@
]
def web_user_gets(domain='user.com'):
return [
requests_response(ACTOR_HTML, url=f'https://{domain}/'),
requests_response(ACTOR_HTML, url=f'https://{domain}/'),
requests_response(status=404), # webfinger
]
Expand Down Expand Up @@ -597,8 +596,6 @@ def test_get_or_create_new_propagate_atproto(self, mock_create_task,
mock_get.side_effect = web_user_gets('new.com')
html = requests_response(ACTOR_HTML, url='https://new.com/')
mock_get.side_effect = [
html,
html,
html,
requests_response(status=404), # webfinger
]
Expand Down Expand Up @@ -1315,7 +1312,7 @@ def test_create_post_use_instead_strip_www(self, mock_get, mock_post):
self.assert_deliveries(mock_post, inboxes, create_as2)

def test_create_post(self, mock_get, mock_post):
mock_get.side_effect = [NOTE, ACTOR]
mock_get.return_value = NOTE
mock_post.return_value = requests_response('abc xyz')
self.make_followers()

Expand Down Expand Up @@ -1345,11 +1342,11 @@ def test_create_post(self, mock_get, mock_post):
)

def test_update_post(self, mock_get, mock_post):
mock_get.side_effect = [NOTE, ACTOR]
mock_get.return_value = NOTE
mock_post.return_value = requests_response('abc xyz')

mf2 = copy.deepcopy(NOTE_MF2)
mf2['properties']['content'] = 'different'
mf2['properties']['content'] = ['different']
Object(id='https://user.com/post', users=[self.user.key], mf2=mf2).put()

self.make_followers()
Expand Down Expand Up @@ -1414,7 +1411,7 @@ def test_create_with_image(self, mock_get, mock_post):
self.assert_equals(create, json_loads(mock_post.call_args[1]['data']))

def test_follow(self, mock_get, mock_post):
mock_get.side_effect = [FOLLOW, ACTOR, WEBMENTION_REL_LINK]
mock_get.side_effect = [FOLLOW, ACTOR, ACTOR, WEBMENTION_REL_LINK]
mock_post.return_value = requests_response('abc xyz')

got = self.post('/queue/webmention', data={
Expand Down Expand Up @@ -1513,9 +1510,7 @@ def test_follow_fragment(self, mock_get, mock_post):
mock_get.side_effect = [
FOLLOW_FRAGMENT,
ACTOR,
FOLLOW_FRAGMENT, # protocol detection: AS2
FOLLOW_FRAGMENT, # protocol detection: HTML
ACTOR, # authorship
ACTOR,
]
mock_post.return_value = requests_response('abc xyz')

Expand Down Expand Up @@ -1797,6 +1792,7 @@ def test_update_profile(self, mock_get, mock_post):
expected_as2)

# updated Web user
ndb.context.get_context().cache.clear()
expected_actor_as2 = {
'@context': [
'https://www.w3.org/ns/activitystreams',
Expand Down Expand Up @@ -2890,8 +2886,6 @@ def test_check_web_site_unicode_domain(self, mock_get, _):

def test_check_web_site_lower_cases_domain(self, mock_get, _):
gets = [
requests_response(ACTOR_HTML, url='https://abc.org/'),
requests_response(ACTOR_HTML, url='https://abc.org/'),
requests_response(ACTOR_HTML, url='https://abc.org/'),
requests_response(status=404), # webfinger
]
Expand Down
9 changes: 2 additions & 7 deletions tests/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from bs4 import MarkupResemblesLocatorWarning
import dag_cbor.random
from google.cloud import ndb
from google.cloud.ndb.global_cache import _InProcessGlobalCache
from google.protobuf.timestamp_pb2 import Timestamp
from granary import as1, as2
from granary.tests.test_as1 import (
Expand All @@ -35,7 +34,7 @@

# other modules are imported _after_ Fake etc classes is defined so that it's in
# PROTOCOLS when URL routes are registered.
from common import long_to_base64, TASKS_LOCATION
from common import long_to_base64, NDB_CONTEXT_KWARGS, TASKS_LOCATION
import ids
import models
from models import KEY_BITS, Object, PROTOCOLS, Target, User
Expand Down Expand Up @@ -333,11 +332,7 @@ def setUp(self):

# clear datastore
requests.post(f'http://{ndb_client.host}/reset')
self.ndb_context = ndb_client.context(
cache_policy=common.cache_policy,
global_cache=_InProcessGlobalCache(),
global_cache_policy=common.global_cache_policy,
global_cache_timeout_policy=common.global_cache_timeout_policy)
self.ndb_context = ndb_client.context(**NDB_CONTEXT_KWARGS)
self.ndb_context.__enter__()

util.now = lambda **kwargs: testutil.NOW
Expand Down

0 comments on commit 32a60c9

Please sign in to comment.