Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LSP Notification Message Type #749

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 54 additions & 22 deletions internal/lsp/lsproto/jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,60 +43,92 @@ func (id *ID) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &id.int)
}

// TODO(jakebailey): NotificationMessage? Use RequestMessage without ID?
type Message struct {
JSONRPC JSONRPCVersion `json:"jsonrpc"`
}

type NotificationMessage struct {
Message
Method Method `json:"method"`
Params any `json:"params"`
}

type RequestMessage struct {
JSONRPC JSONRPCVersion `json:"jsonrpc"`
ID *ID `json:"id"`
Method Method `json:"method"`
Params any `json:"params"`
Message
ID *ID `json:"id"`
Method Method `json:"method"`
Params any `json:"params"`
}

func (r *RequestMessage) UnmarshalJSON(data []byte) error {
type RequestOrNotificationMessage struct {
NotificationMessage *NotificationMessage
RequestMessage *RequestMessage
}

func (r *RequestOrNotificationMessage) UnmarshalJSON(data []byte) error {
var raw struct {
JSONRPC JSONRPCVersion `json:"jsonrpc"`
ID *ID `json:"id"`
ID *ID `json:"id,omitzero"`
Method Method `json:"method"`
Params json.RawMessage `json:"params"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("%w: %w", ErrInvalidRequest, err)
}

r.ID = raw.ID
r.Method = raw.Method
if r.Method == MethodShutdown || r.Method == MethodExit {
params, err := unmarshalParams(raw.Method, raw.Params)
if err != nil {
return err
}

if raw.ID != nil {
r.RequestMessage = &RequestMessage{
ID: raw.ID,
Method: raw.Method,
Params: params,
}
} else {
r.NotificationMessage = &NotificationMessage{
Method: raw.Method,
Params: params,
}
}

return nil
}

func unmarshalParams(rawMethod Method, rawParams []byte) (any, error) {
if rawMethod == MethodShutdown || rawMethod == MethodExit {
// These methods have no params.
return nil
return nil, nil
}

var params any
var err error

if unmarshalParams, ok := unmarshallers[raw.Method]; ok {
params, err = unmarshalParams(raw.Params)
if unmarshaller, ok := unmarshallers[rawMethod]; ok {
params, err = unmarshaller(rawParams)
} else {
// Fall back to default; it's probably an unknown message and we will probably not handle it.
err = json.Unmarshal(raw.Params, &params)
err = json.Unmarshal(rawParams, &params)
}
r.Params = params

if err != nil {
return fmt.Errorf("%w: %w", ErrInvalidRequest, err)
return nil, fmt.Errorf("%w: %w", ErrInvalidRequest, err)
}

return nil
return params, nil
}

type ResponseMessage struct {
JSONRPC JSONRPCVersion `json:"jsonrpc"`
ID *ID `json:"id,omitempty"`
Result any `json:"result"`
Error *ResponseError `json:"error,omitempty"`
Message
ID *ID `json:"id,omitzero"`
Result any `json:"result"`
Error *ResponseError `json:"error,omitzero"`
}

type ResponseError struct {
Code int32 `json:"code"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
Data any `json:"data,omitzero"`
}
102 changes: 55 additions & 47 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,20 @@ func (s *Server) Run() error {
}

if s.initializeParams == nil {
if req.Method == lsproto.MethodInitialize {
if err := s.handleInitialize(req); err != nil {
return err
}
} else {
if err := s.sendError(req.ID, lsproto.ErrServerNotInitialized); err != nil {
return err
if req.RequestMessage != nil {
message := req.RequestMessage

if message.Method == lsproto.MethodInitialize {
if err := s.handleInitialize(message); err != nil {
return err
}
} else {
if err := s.sendError(message.ID, lsproto.ErrServerNotInitialized); err != nil {
return err
}
}
}

continue
}

Expand All @@ -123,13 +128,13 @@ func (s *Server) Run() error {
}
}

func (s *Server) read() (*lsproto.RequestMessage, error) {
func (s *Server) read() (*lsproto.RequestOrNotificationMessage, error) {
data, err := s.r.Read()
if err != nil {
return nil, err
}

req := &lsproto.RequestMessage{}
req := &lsproto.RequestOrNotificationMessage{}
if err := json.Unmarshal(data, req); err != nil {
return nil, fmt.Errorf("%w: %w", lsproto.ErrInvalidRequest, err)
}
Expand Down Expand Up @@ -170,45 +175,45 @@ func (s *Server) sendResponse(resp *lsproto.ResponseMessage) error {
return s.w.Write(data)
}

func (s *Server) handleMessage(req *lsproto.RequestMessage) error {
s.requestTime = time.Now()
s.requestMethod = string(req.Method)

params := req.Params
switch params.(type) {
case *lsproto.InitializeParams:
return s.sendError(req.ID, lsproto.ErrInvalidRequest)
case *lsproto.InitializedParams:
return s.handleInitialized(req)
case *lsproto.DidOpenTextDocumentParams:
return s.handleDidOpen(req)
case *lsproto.DidChangeTextDocumentParams:
return s.handleDidChange(req)
case *lsproto.DidSaveTextDocumentParams:
return s.handleDidSave(req)
case *lsproto.DidCloseTextDocumentParams:
return s.handleDidClose(req)
case *lsproto.DocumentDiagnosticParams:
return s.handleDocumentDiagnostic(req)
case *lsproto.HoverParams:
return s.handleHover(req)
case *lsproto.DefinitionParams:
return s.handleDefinition(req)
default:
func (s *Server) handleMessage(msg *lsproto.RequestOrNotificationMessage) error {
if req := msg.RequestMessage; req != nil {
switch req.Method {
case lsproto.MethodInitialize:
return s.sendError(req.ID, lsproto.ErrInvalidRequest)
case lsproto.MethodTextDocumentDiagnostic:
return s.handleDocumentDiagnostic(req)
case lsproto.MethodTextDocumentHover:
return s.handleHover(req)
case lsproto.MethodTextDocumentDefinition:
return s.handleDefinition(req)
case lsproto.MethodShutdown:
s.projectService.Close()
return s.sendResult(req.ID, nil)
case lsproto.MethodExit:
return nil
default:
s.Log("unknown method", req.Method)
if req.ID != nil {
return s.sendError(req.ID, lsproto.ErrInvalidRequest)
}
}
} else if notif := msg.NotificationMessage; notif != nil {
switch notif.Method {
case lsproto.MethodInitialized:
return s.handleInitialized()
case lsproto.MethodTextDocumentDidOpen:
return s.handleDidOpen(notif)
case lsproto.MethodTextDocumentDidChange:
return s.handleDidChange(notif)
case lsproto.MethodTextDocumentDidSave:
return s.handleDidSave(notif)
case lsproto.MethodTextDocumentDidClose:
return s.handleDidClose(notif)
case lsproto.MethodExit:
return nil
default:
s.Log("unknown method", notif.Method)
}
} else {
s.Log("Failed to parse unknown message")
}

return nil
}

func (s *Server) handleInitialize(req *lsproto.RequestMessage) error {
Expand Down Expand Up @@ -254,7 +259,7 @@ func (s *Server) handleInitialize(req *lsproto.RequestMessage) error {
})
}

func (s *Server) handleInitialized(req *lsproto.RequestMessage) error {
func (s *Server) handleInitialized() error {
s.logger = project.NewLogger([]io.Writer{s.stderr}, project.LogLevelVerbose)
s.projectService = project.NewService(s, project.ServiceOptions{
DefaultLibraryPath: s.defaultLibraryPath,
Expand All @@ -269,24 +274,26 @@ func (s *Server) handleInitialized(req *lsproto.RequestMessage) error {
return nil
}

func (s *Server) handleDidOpen(req *lsproto.RequestMessage) error {
func (s *Server) handleDidOpen(req *lsproto.NotificationMessage) error {
params := req.Params.(*lsproto.DidOpenTextDocumentParams)
s.projectService.OpenFile(ls.DocumentURIToFileName(params.TextDocument.Uri), params.TextDocument.Text, ls.LanguageKindToScriptKind(params.TextDocument.LanguageId), "")
return nil
}

func (s *Server) handleDidChange(req *lsproto.RequestMessage) error {
func (s *Server) handleDidChange(req *lsproto.NotificationMessage) error {
params := req.Params.(*lsproto.DidChangeTextDocumentParams)
scriptInfo := s.projectService.GetScriptInfo(ls.DocumentURIToFileName(params.TextDocument.Uri))
if scriptInfo == nil {
return s.sendError(req.ID, lsproto.ErrRequestFailed)
s.logger.Error("Failed to get script info")
return nil
}

changes := make([]ls.TextChange, len(params.ContentChanges))
for i, change := range params.ContentChanges {
if partialChange := change.TextDocumentContentChangePartial; partialChange != nil {
if textChange, err := s.converters.FromLSPTextChange(partialChange, scriptInfo.FileName()); err != nil {
return s.sendError(req.ID, err)
s.logger.Error(fmt.Sprintf("Error converting %v:", err))
return nil
} else {
changes[i] = textChange
}
Expand All @@ -296,21 +303,22 @@ func (s *Server) handleDidChange(req *lsproto.RequestMessage) error {
NewText: wholeChange.Text,
}
} else {
return s.sendError(req.ID, lsproto.ErrInvalidRequest)
s.logger.Error(fmt.Sprintf("Invalid request"))
return nil
}
}

s.projectService.ChangeFile(ls.DocumentURIToFileName(params.TextDocument.Uri), changes)
return nil
}

func (s *Server) handleDidSave(req *lsproto.RequestMessage) error {
func (s *Server) handleDidSave(req *lsproto.NotificationMessage) error {
params := req.Params.(*lsproto.DidSaveTextDocumentParams)
s.projectService.MarkFileSaved(ls.DocumentURIToFileName(params.TextDocument.Uri), *params.Text)
return nil
}

func (s *Server) handleDidClose(req *lsproto.RequestMessage) error {
func (s *Server) handleDidClose(req *lsproto.NotificationMessage) error {
params := req.Params.(*lsproto.DidCloseTextDocumentParams)
s.projectService.CloseFile(ls.DocumentURIToFileName(params.TextDocument.Uri))
return nil
Expand Down