Skip to content

Commit

Permalink
Provide helpful error messages and status codes for BGS endpoints (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericvolp12 authored Oct 2, 2023
2 parents 11b4a36 + b02b058 commit 6acd4e5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 29 deletions.
58 changes: 34 additions & 24 deletions bgs/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (

atproto "github.com/bluesky-social/indigo/api/atproto"
comatprototypes "github.com/bluesky-social/indigo/api/atproto"
"github.com/bluesky-social/indigo/blobs"
"github.com/bluesky-social/indigo/mst"
"gorm.io/gorm"

"github.com/bluesky-social/indigo/util"
Expand All @@ -26,6 +28,7 @@ func (s *BGS) handleComAtprotoSyncGetRecord(ctx context.Context, collection stri
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, echo.NewHTTPError(http.StatusNotFound, "user not found")
}
log.Errorw("failed to lookup user", "err", err, "did", did)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to lookup user")
}

Expand All @@ -41,19 +44,25 @@ func (s *BGS) handleComAtprotoSyncGetRecord(ctx context.Context, collection stri
if commit != "" {
reqCid, err = cid.Decode(commit)
if err != nil {
return nil, fmt.Errorf("failed to decode commit cid: %w", err)
log.Errorw("failed to decode commit cid", "err", err, "cid", commit)
return nil, echo.NewHTTPError(http.StatusBadRequest, "failed to decode commit cid")
}
}

_, record, err := s.repoman.GetRecord(ctx, u.ID, collection, rkey, reqCid)
if err != nil {
return nil, fmt.Errorf("failed to get record: %w", err)
if errors.Is(err, mst.ErrNotFound) {
return nil, echo.NewHTTPError(http.StatusNotFound, "record not found in repo")
}
log.Errorw("failed to get record from repo", "err", err, "did", did, "collection", collection, "rkey", rkey)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get record from repo")
}

buf := new(bytes.Buffer)
err = record.MarshalCBOR(buf)
if err != nil {
return nil, fmt.Errorf("failed to marshal record: %w", err)
log.Errorw("failed to marshal record to CBOR", "err", err, "did", did, "collection", collection, "rkey", rkey)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to marshal record to CBOR")
}

return buf, nil
Expand All @@ -65,6 +74,7 @@ func (s *BGS) handleComAtprotoSyncGetRepo(ctx context.Context, did string, since
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, echo.NewHTTPError(http.StatusNotFound, "user not found")
}
log.Errorw("failed to lookup user", "err", err, "did", did)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to lookup user")
}

Expand All @@ -79,7 +89,8 @@ func (s *BGS) handleComAtprotoSyncGetRepo(ctx context.Context, did string, since
// TODO: stream the response
buf := new(bytes.Buffer)
if err := s.repoman.ReadRepo(ctx, u.ID, since, buf); err != nil {
return nil, fmt.Errorf("failed to read repo: %w", err)
log.Errorw("failed to read repo into buffer", "err", err, "did", did)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to read repo into buffer")
}

return buf, nil
Expand All @@ -92,27 +103,21 @@ func (s *BGS) handleComAtprotoSyncGetBlocks(ctx context.Context, cids []string,
func (s *BGS) handleComAtprotoSyncRequestCrawl(ctx context.Context, body *comatprototypes.SyncRequestCrawl_Input) error {
host := body.Hostname
if host == "" {
return fmt.Errorf("must pass valid hostname")
return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname")
}

if strings.HasPrefix(host, "https://") || strings.HasPrefix(host, "http://") {
return &echo.HTTPError{
Code: 400,
Message: "must pass domain without protocol scheme",
}
return echo.NewHTTPError(http.StatusBadRequest, "must pass domain without protocol scheme")
}

norm, err := util.NormalizeHostname(host)
if err != nil {
return err
return echo.NewHTTPError(http.StatusBadRequest, "failed to normalize hostname")
}

banned, err := s.domainIsBanned(ctx, host)
if banned {
return &echo.HTTPError{
Code: 401,
Message: "domain is banned",
}
return echo.NewHTTPError(http.StatusUnauthorized, "domain is banned")
}

log.Warnf("TODO: better host validation for crawl requests")
Expand All @@ -128,10 +133,7 @@ func (s *BGS) handleComAtprotoSyncRequestCrawl(ctx context.Context, body *comatp

desc, err := atproto.ServerDescribeServer(ctx, c)
if err != nil {
return &echo.HTTPError{
Code: 401,
Message: fmt.Sprintf("given host failed to respond to ping: %s", err),
}
return echo.NewHTTPError(http.StatusBadRequest, "requested host failed to respond to describe request")
}

// Maybe we could do something with this response later
Expand All @@ -152,7 +154,11 @@ func (s *BGS) handleComAtprotoSyncGetBlob(ctx context.Context, cid string, did s

b, err := s.blobs.GetBlob(ctx, cid, did)
if err != nil {
return nil, err
if errors.Is(err, blobs.NotFoundErr) {
return nil, echo.NewHTTPError(http.StatusNotFound, "blob not found")
}
log.Errorw("failed to get blob", "err", err, "cid", cid, "did", did)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get blob")
}

return bytes.NewReader(b), nil
Expand All @@ -169,7 +175,7 @@ func (s *BGS) handleComAtprotoSyncListRepos(ctx context.Context, cursor string,
if cursor != "" {
c, err = strconv.ParseInt(cursor, 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid cursor: %w", err)
return nil, echo.NewHTTPError(http.StatusBadRequest, "couldn't parse your cursor as an integer")
}
}

Expand All @@ -178,7 +184,8 @@ func (s *BGS) handleComAtprotoSyncListRepos(ctx context.Context, cursor string,
if err == gorm.ErrRecordNotFound {
return &comatprototypes.SyncListRepos_Output{}, nil
}
return nil, fmt.Errorf("failed to get users: %w", err)
log.Errorw("failed to query users", "err", err)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to query users")
}

if len(users) == 0 {
Expand All @@ -194,7 +201,8 @@ func (s *BGS) handleComAtprotoSyncListRepos(ctx context.Context, cursor string,

root, err := s.repoman.GetRepoRoot(ctx, user.ID)
if err != nil {
return nil, fmt.Errorf("failed to get repo root for (%s): %w", user.Did, err)
log.Errorw("failed to get repo root", "err", err, "did", user.Did)
return nil, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to get repo root for (%s): %v", user.Did, err.Error()))
}

resp.Repos = append(resp.Repos, &comatprototypes.SyncListRepos_Repo{
Expand Down Expand Up @@ -229,12 +237,14 @@ func (s *BGS) handleComAtprotoSyncGetLatestCommit(ctx context.Context, did strin

root, err := s.repoman.GetRepoRoot(ctx, u.ID)
if err != nil {
return nil, fmt.Errorf("failed to get repo root: %w", err)
log.Errorw("failed to get repo root", "err", err, "did", u.Did)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get repo root")
}

rev, err := s.repoman.GetRepoRev(ctx, u.ID)
if err != nil {
return nil, fmt.Errorf("failed to get repo rev: %w", err)
log.Errorw("failed to get repo rev", "err", err, "did", u.Did)
return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to get repo rev")
}

return &comatprototypes.SyncGetLatestCommit_Output{
Expand Down
11 changes: 11 additions & 0 deletions blobs/blobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package blobs

import (
"context"
"fmt"
"os"
"path/filepath"
)

var NotFoundErr = fmt.Errorf("blob not found")

type BlobStore interface {
PutBlob(ctx context.Context, cid string, did string, blob []byte) error
GetBlob(ctx context.Context, cid string, did string) ([]byte, error)
Expand All @@ -25,5 +28,13 @@ func (dbs *DiskBlobStore) PutBlob(ctx context.Context, cid string, did string, b
}

func (dbs *DiskBlobStore) GetBlob(ctx context.Context, cid string, did string) ([]byte, error) {
// Check if the blob exists
_, err := os.Stat(filepath.Join(dbs.Dir, did, cid))
if err != nil {
if os.IsNotExist(err) {
return nil, NotFoundErr
}
return nil, err
}
return os.ReadFile(filepath.Join(dbs.Dir, did, cid))
}
14 changes: 9 additions & 5 deletions testing/integ_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,18 +380,22 @@ func TestDomainBans(t *testing.T) {
t.Fatal("domain should be banned")
}

if err := atproto.SyncRequestCrawl(context.TODO(), c, &atproto.SyncRequestCrawl_Input{Hostname: "app.pds.foo.com"}); err == nil {
err := atproto.SyncRequestCrawl(context.TODO(), c, &atproto.SyncRequestCrawl_Input{Hostname: "app.pds.foo.com"})
if err == nil {
t.Fatal("domain should be banned")
}

if !strings.Contains(err.Error(), "XRPC ERROR 401") {
t.Fatal("should have failed with a 401")
}

// should not be banned
err := atproto.SyncRequestCrawl(context.TODO(), c, &atproto.SyncRequestCrawl_Input{Hostname: "foo.bar.com"})
err = atproto.SyncRequestCrawl(context.TODO(), c, &atproto.SyncRequestCrawl_Input{Hostname: "foo.bar.com"})
if err == nil {
t.Fatal("should still fail")
}

if !strings.Contains(err.Error(), "XRPC ERROR 401") {
t.Fatal("should have failed with a 401")
if !strings.Contains(err.Error(), "XRPC ERROR 400") {
t.Fatal("should have failed with a 400")
}

}

0 comments on commit 6acd4e5

Please sign in to comment.