Skip to content

Commit

Permalink
Merge pull request #195 from kitagry/add-definition
Browse files Browse the repository at this point in the history
feat: add definition
  • Loading branch information
kitagry authored Dec 22, 2024
2 parents ee094cf + 3416626 commit de44475
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 8 deletions.
22 changes: 22 additions & 0 deletions langserver/definition.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package langserver

import (
"context"
"encoding/json"

"github.com/kitagry/bqls/langserver/internal/lsp"
"github.com/sourcegraph/jsonrpc2"
)

func (h *Handler) handleTextDocumentDefinition(ctx context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) {
if req.Params == nil {
return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInvalidParams}
}

var params lsp.TextDocumentPositionParams
if err := json.Unmarshal(*req.Params, &params); err != nil {
return nil, err
}

return h.project.LookupIdent(ctx, params.TextDocument.URI.ToURI(), params.Position)
}
2 changes: 1 addition & 1 deletion langserver/hover.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (h *Handler) handleTextDocumentHover(ctx context.Context, conn *jsonrpc2.Co
}

func (h *Handler) documentIdent(ctx context.Context, uri lsp.DocumentURI, position lsp.Position) (lsp.Hover, error) {
result, err := h.project.TermDocument(documentURIToURI(uri), position)
result, err := h.project.TermDocument(ctx, documentURIToURI(uri), position)
if err != nil {
return lsp.Hover{}, err
}
Expand Down
1 change: 1 addition & 0 deletions langserver/initialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func (h *Handler) handleInitialize(ctx context.Context, conn *jsonrpc2.Conn, req
TextDocumentSync: &lsp.TextDocumentSyncOptionsOrKind{
Kind: toPtr(lsp.TDSKFull),
},
DefinitionProvider: true,
DocumentFormattingProvider: true,
HoverProvider: true,
CodeActionProvider: true,
Expand Down
8 changes: 8 additions & 0 deletions langserver/internal/lsp/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ func (p *InitializeParams[T]) Root() DocumentURI {

type DocumentURI string

func NewDocumentURI(uri string) DocumentURI {
return DocumentURI("file://" + uri)
}

func (d DocumentURI) ToURI() string {
return string(d)[len("file://"):]
}

type ClientInfo struct {
Name string `json:"name,omitempty"`
Version string `json:"version,omitempty"`
Expand Down
58 changes: 58 additions & 0 deletions langserver/internal/source/definition.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package source

import (
"context"
"fmt"
"strings"

"github.com/goccy/go-zetasql/ast"
"github.com/kitagry/bqls/langserver/internal/lsp"
"github.com/kitagry/bqls/langserver/internal/source/file"
)

func (p *Project) LookupIdent(ctx context.Context, uri string, position lsp.Position) ([]lsp.Location, error) {
sql := p.cache.Get(uri)
parsedFile := p.analyzer.ParseFile(uri, sql.RawText)

termOffset := parsedFile.TermOffset(position)

tablePathExpression, ok := file.SearchAstNode[*ast.TablePathExpressionNode](parsedFile.Node, termOffset)
if !ok {
return nil, fmt.Errorf("not found")
}

pathExpr := tablePathExpression.PathExpr()
if pathExpr == nil {
return nil, fmt.Errorf("not found")
}

tableNames := make([]string, len(pathExpr.Names()))
for i, n := range tablePathExpression.PathExpr().Names() {
tableNames[i] = n.Name()
}
tableName := strings.Join(tableNames, ".")

withClauseEntries := file.ListAstNode[*ast.WithClauseEntryNode](parsedFile.Node)
for _, entry := range withClauseEntries {
if entry.Alias().Name() != tableName {
continue
}

locRange := entry.Alias().ParseLocationRange()
if locRange == nil {
continue
}

r, ok := parsedFile.ToLspRange(locRange)
if !ok {
continue
}

return []lsp.Location{{
URI: lsp.NewDocumentURI(uri),
Range: r,
}}, nil
}

return nil, fmt.Errorf("not found")
}
109 changes: 109 additions & 0 deletions langserver/internal/source/definition_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package source_test

import (
"context"
"errors"
"testing"

bq "cloud.google.com/go/bigquery"
"github.com/golang/mock/gomock"
"github.com/google/go-cmp/cmp"
"github.com/kitagry/bqls/langserver/internal/bigquery/mock_bigquery"
"github.com/kitagry/bqls/langserver/internal/lsp"
"github.com/kitagry/bqls/langserver/internal/source"
"github.com/kitagry/bqls/langserver/internal/source/helper"
"github.com/sirupsen/logrus"
)

func TestProject_LookupIdent(t *testing.T) {
tests := map[string]struct {
// prepare
files map[string]string
bqTableMetadata *bq.TableMetadata

// output
expectLocations []lsp.Location
expectErr error
}{
"definition to with clause": {
files: map[string]string{
"a.sql": `WITH data AS ( SELECT 1 AS a )
SELECT a FROM data|`,
},
bqTableMetadata: &bq.TableMetadata{
FullID: "project.dataset.table",
Schema: bq.Schema{},
},
expectLocations: []lsp.Location{
{
URI: lsp.NewDocumentURI("a.sql"),
Range: lsp.Range{
Start: lsp.Position{
Line: 0,
Character: 5,
},
End: lsp.Position{
Line: 0,
Character: 9,
},
},
},
},
},
"definition to 2 with clause": {
files: map[string]string{
"a.sql": `WITH data AS ( SELECT 1 AS a ),
data2 AS (SELECT * FROM data )
SELECT a FROM data2|`,
},
bqTableMetadata: &bq.TableMetadata{
FullID: "project.dataset.table",
Schema: bq.Schema{},
},
expectLocations: []lsp.Location{
{
URI: lsp.NewDocumentURI("a.sql"),
Range: lsp.Range{
Start: lsp.Position{
Line: 1,
Character: 0,
},
End: lsp.Position{
Line: 1,
Character: 5,
},
},
},
},
},
}

for n, tt := range tests {
t.Run(n, func(t *testing.T) {
ctrl := gomock.NewController(t)
bqClient := mock_bigquery.NewMockClient(ctrl)
bqClient.EXPECT().GetTableMetadata(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tt.bqTableMetadata, nil).MinTimes(0)
logger := logrus.New()
logger.SetLevel(logrus.DebugLevel)
p := source.NewProjectWithBQClient("/", bqClient, logger)

files, path, position, err := helper.GetLspPosition(tt.files)
if err != nil {
t.Fatalf("failed to get position: %v", err)
}

for uri, content := range files {
p.UpdateFile(uri, content, 1)
}

got, err := p.LookupIdent(context.Background(), path, position)
if !errors.Is(err, tt.expectErr) {
t.Fatalf("got error %v, but want %v", err, tt.expectErr)
}

if diff := cmp.Diff(tt.expectLocations, got); diff != "" {
t.Errorf("project.TermDocument result diff (-expect, +got)\n%s", diff)
}
})
}
}
3 changes: 1 addition & 2 deletions langserver/internal/source/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ import (
"golang.org/x/text/message"
)

func (p *Project) TermDocument(uri string, position lsp.Position) ([]lsp.MarkedString, error) {
ctx := context.Background()
func (p *Project) TermDocument(ctx context.Context, uri string, position lsp.Position) ([]lsp.MarkedString, error) {
sql := p.cache.Get(uri)
parsedFile := p.analyzer.ParseFile(uri, sql.RawText)

Expand Down
3 changes: 2 additions & 1 deletion langserver/internal/source/document_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package source_test

import (
"context"
"errors"
"testing"
"time"
Expand Down Expand Up @@ -468,7 +469,7 @@ table description
p.UpdateFile(uri, content, 1)
}

got, err := p.TermDocument(path, position)
got, err := p.TermDocument(context.Background(), path, position)
if !errors.Is(err, tt.expectErr) {
t.Fatalf("got error %v, but want %v", err, tt.expectErr)
}
Expand Down
35 changes: 32 additions & 3 deletions langserver/internal/source/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,42 @@ func (p ParsedFile) ExtractSQL(locationRange *types.ParseLocationRange) (string,
return "", false
}

startOffset := p.fixTermOFfsetForSQL(locationRange.Start().ByteOffset())
endOffset := p.fixTermOFfsetForSQL(locationRange.End().ByteOffset())
startOffset := p.fixTermOffsetForSQL(locationRange.Start().ByteOffset())
endOffset := p.fixTermOffsetForSQL(locationRange.End().ByteOffset())

return p.Src[startOffset:endOffset], true
}

func (p ParsedFile) fixTermOFfsetForSQL(termOffset int) int {
func (p ParsedFile) ToLspRange(r *types.ParseLocationRange) (lsp.Range, bool) {
if r == nil {
return lsp.Range{}, false
}

return lsp.Range{
Start: p.toLspPoint(r.Start()),
End: p.toLspPoint(r.End()),
}, true
}

func (p ParsedFile) toLspPoint(point *types.ParseLocationPoint) lsp.Position {
offset := p.fixTermOffsetForSQL(point.ByteOffset())

toEndText := p.Src[:offset]
line := strings.Count(toEndText, "\n")
newLineInd := strings.LastIndex(toEndText, "\n")
var char int
if newLineInd == -1 {
char = len(toEndText)
} else {
char = len(toEndText[newLineInd:]) - 1
}
return lsp.Position{
Line: line,
Character: char,
}
}

func (p ParsedFile) fixTermOffsetForSQL(termOffset int) int {
for _, fo := range p.FixOffsets {
if termOffset > fo.Offset+fo.Length {
termOffset -= fo.Length
Expand Down
4 changes: 3 additions & 1 deletion langserver/langserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (h *Handler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2
defer func() {
err := recover()
if err != nil {
h.logger.Errorf("panic: %v", err)
h.logger.Errorf("panic: %#v", err)
}
}()
jsonrpc2.HandlerWithError(h.handle).Handle(ctx, conn, req)
Expand Down Expand Up @@ -99,6 +99,8 @@ func (h *Handler) handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2
return ignoreMiddleware(h.handleTextDocumentHover)(ctx, conn, req)
case "textDocument/completion":
return ignoreMiddleware(h.handleTextDocumentCompletion)(ctx, conn, req)
case "textDocument/definition":
return ignoreMiddleware(h.handleTextDocumentDefinition)(ctx, conn, req)
case "textDocument/codeAction":
return h.handleTextDocumentCodeAction(ctx, conn, req)
case "workspace/executeCommand":
Expand Down

0 comments on commit de44475

Please sign in to comment.