Skip to content

Commit

Permalink
refac/plugin_with_configs (#53)
Browse files Browse the repository at this point in the history
* plugin minor refactor

* passing config to OnReceive

* moved main plugins list to main

* removed unnecessary from method name

* naming changes
  • Loading branch information
danielbonilha authored Nov 10, 2022
1 parent 712ca71 commit 1883c87
Show file tree
Hide file tree
Showing 10 changed files with 311 additions and 251 deletions.
2 changes: 2 additions & 0 deletions gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func Run() {
panic(err)
}

transport.LoadPlugins()

profile := os.Getenv("PROFILE")
idProvider := idp.NewProvider(profile)

Expand Down
54 changes: 54 additions & 0 deletions gateway/plugin/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package plugin

import "time"

type (
Config struct {
SessionId string
ConnectionId string
ConnectionName string
ConnectionType string
Org string
User string
Hostname string
MachineId string
KernelVersion string
ParamsData map[string]any
}
)

func (c Config) Get(key string) any {
return c.ParamsData[key]
}

func (c Config) GetByte(key string) []byte {
val, ok := c.ParamsData[key]
if !ok {
return nil
}
return val.([]byte)
}

func (c Config) GetString(key string) string {
val, ok := c.ParamsData[key]
if !ok {
return ""
}
return val.(string)
}

func (c Config) GetDuration(key string) time.Duration {
val, ok := c.ParamsData[key]
if !ok {
return 0
}
return val.(time.Duration)
}

func (c Config) GetTime(key string) *time.Time {
val, ok := c.ParamsData[key]
if !ok {
return nil
}
return val.(*time.Time)
}
16 changes: 8 additions & 8 deletions gateway/session/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package session

import (
"fmt"
"github.com/runopsio/hoop/gateway/plugin"
"log"
"strings"

st "github.com/runopsio/hoop/gateway/storage"
"github.com/runopsio/hoop/gateway/transport/plugins"
"github.com/runopsio/hoop/gateway/user"
"olympos.io/encoding/edn"
)
Expand Down Expand Up @@ -180,17 +180,17 @@ func (s *Storage) NewGenericStorageWriter() *GenericStorageWriter {
}
}

func (s *GenericStorageWriter) Write(p plugins.ParamsData) error {
log.Printf("saving session=%v, org-id=%v\n", p.Get("session_id"), p.Get("org_id"))
func (s *GenericStorageWriter) Write(p plugin.Config) error {
log.Printf("saving session=%v, org-id=%v\n", p.SessionId, p.Org)
eventStartDate := p.GetTime("start_date")
if eventStartDate == nil {
return fmt.Errorf(`missing "start_date" param`)
}
sess := &Session{
ID: p.GetString("session_id"),
User: p.GetString("user_id"),
Type: p.GetString("connection_type"),
Connection: p.GetString("connection_name"),
ID: p.SessionId,
User: p.User,
Type: p.ConnectionType,
Connection: p.ConnectionName,
NonIndexedStream: nil,
StartSession: *eventStartDate,
EndSession: p.GetTime("end_time"),
Expand All @@ -204,6 +204,6 @@ func (s *GenericStorageWriter) Write(p plugins.ParamsData) error {
}
sess.NonIndexedStream = nonIndexedEventStream
}
_, err := s.persistFn(&user.Context{Org: &user.Org{Id: p.GetString("org_id")}}, sess)
_, err := s.persistFn(&user.Context{Org: &user.Org{Id: p.Org}}, sess)
return err
}
6 changes: 3 additions & 3 deletions gateway/transport/agent.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package transport

import (
"github.com/runopsio/hoop/gateway/plugin"
"io"
"log"
"os"
Expand All @@ -10,7 +11,6 @@ import (

pb "github.com/runopsio/hoop/common/proto"
"github.com/runopsio/hoop/gateway/agent"
"github.com/runopsio/hoop/gateway/transport/plugins"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -74,7 +74,7 @@ func (s *Server) subscribeAgent(stream pb.Transport_ConnectServer, token string)

log.Printf("successful connection hostname: [%s], machineId [%s], kernelVersion [%s]", hostname, machineId, kernelVersion)
agentErr := s.listenAgentMessages(ag, stream)
if err := s.pluginOnDisconnectPhase(plugins.ParamsData{"client": "agent"}); err != nil {
if err := s.pluginOnDisconnect(plugin.Config{ParamsData: map[string]any{"client": "agent"}}); err != nil {
log.Printf("ua=agent - failed processing plugin on-disconnect phase, err=%v", err)
}
s.disconnectAgent(ag)
Expand Down Expand Up @@ -109,7 +109,7 @@ func (s *Server) listenAgentMessages(ag *agent.Agent, stream pb.Transport_Connec
continue
}
sessionID := string(pkt.Spec[pb.SpecGatewaySessionID])
if err := s.pluginOnReceivePhase(sessionID, pkt); err != nil {
if err := s.pluginOnReceive(sessionID, pkt); err != nil {
log.Printf("plugin reject packet, err=%v", err)
return status.Errorf(codes.Internal, "internal error, plugin reject packet")
}
Expand Down
159 changes: 139 additions & 20 deletions gateway/transport/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package transport

import (
"fmt"
"github.com/runopsio/hoop/gateway/plugin"
pluginsaudit "github.com/runopsio/hoop/gateway/transport/plugins/audit"
pluginsreview "github.com/runopsio/hoop/gateway/transport/plugins/review"
"github.com/runopsio/hoop/gateway/user"
"io"
"log"
"os"
Expand All @@ -13,7 +17,6 @@ import (
pb "github.com/runopsio/hoop/common/proto"
"github.com/runopsio/hoop/gateway/client"
"github.com/runopsio/hoop/gateway/connection"
"github.com/runopsio/hoop/gateway/transport/plugins"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
Expand All @@ -23,22 +26,51 @@ type (
connectedClients struct {
clients map[string]pb.Transport_ConnectServer
connections map[string]*connection.Connection
plugins map[string][]pluginConfig
mu sync.Mutex
}

pluginConfig struct {
Plugin
config []string
}

Plugin interface {
Name() string
OnStartup(config plugin.Config) error
OnConnect(p plugin.Config) error
OnReceive(sessionID string, config []string, packet *pb.Packet) error
OnDisconnect(p plugin.Config) error
}
)

var allPlugins []Plugin

var cc = connectedClients{
clients: make(map[string]pb.Transport_ConnectServer),
connections: make(map[string]*connection.Connection),
plugins: make(map[string][]pluginConfig),
mu: sync.Mutex{},
}

func bindClient(sessionID string, stream pb.Transport_ConnectServer, connection *connection.Connection) {
func LoadPlugins() {
allPlugins = []Plugin{
pluginsaudit.New(),
pluginsreview.New(),
}
}

func bindClient(sessionID string,
stream pb.Transport_ConnectServer,
connection *connection.Connection,
pluginsConfig []pluginConfig) {

cc.mu.Lock()
defer cc.mu.Unlock()

cc.clients[sessionID] = stream
cc.connections[sessionID] = connection
cc.plugins[sessionID] = pluginsConfig
}

func unbindClient(id string) {
Expand All @@ -47,12 +79,17 @@ func unbindClient(id string) {

delete(cc.clients, id)
delete(cc.connections, id)
delete(cc.plugins, id)
}

func getClientStream(id string) pb.Transport_ConnectServer {
return cc.clients[id]
}

func getPlugins(id string) []pluginConfig {
return cc.plugins[id]
}

func (s *Server) subscribeClient(stream pb.Transport_ConnectServer, token string) error {
ctx := stream.Context()
md, _ := metadata.FromIncomingContext(ctx)
Expand Down Expand Up @@ -95,33 +132,33 @@ func (s *Server) subscribeClient(stream pb.Transport_ConnectServer, token string
AgentId: conn.AgentId,
}

err = s.pluginOnConnectPhase(plugins.ParamsData{
"session_id": sessionID,
"connection_id": conn.Id,
"connection_name": connectionName,
"connection_type": string(conn.Type),
"org_id": context.Org.Id,
"user_id": context.User.Id,
"hostname": hostname,
"machine_id": machineId,
"kernel_version": kernelVersion,
}, context)
pConfig := plugin.Config{
SessionId: sessionID,
ConnectionId: conn.Id,
ConnectionName: connectionName,
ConnectionType: string(conn.Type),
Org: context.Org.Id,
User: context.User.Id,
Hostname: hostname,
MachineId: machineId,
KernelVersion: kernelVersion,
ParamsData: make(map[string]any),
}

plugins, err := s.loadConnectPlugins(context, pConfig)
if err != nil {
log.Printf("plugin refused to accept connection %q, err=%v", sessionID, err)
return status.Errorf(codes.FailedPrecondition, err.Error())
}

s.ClientService.Persist(c)
bindClient(c.SessionID, stream, conn)
bindClient(c.SessionID, stream, conn, plugins)

s.clientGracefulShutdown(c)

log.Printf("successful connection hostname: [%s], machineId [%s], kernelVersion [%s]", hostname, machineId, kernelVersion)
clientErr := s.listenClientMessages(stream, c, conn)
if err := s.pluginOnDisconnectPhase(plugins.ParamsData{
"org_id": context.Org.Id,
"session_id": sessionID,
}); err != nil {

if err := s.pluginOnDisconnect(pConfig); err != nil {
log.Printf("session=%v ua=client - failed processing plugin on-disconnect phase, err=%v", sessionID, err)
}

Expand Down Expand Up @@ -166,7 +203,7 @@ func (s *Server) listenClientMessages(stream pb.Transport_ConnectServer, c *clie
return status.Errorf(codes.FailedPrecondition, fmt.Sprintf("agent not found for %v", c.AgentId))
}
log.Printf("receive client packet type [%s] and session id [%s]", pkt.Type, c.SessionID)
if err := s.pluginOnReceivePhase(c.SessionID, pkt); err != nil {
if err := s.pluginOnReceive(c.SessionID, pkt); err != nil {
log.Printf("plugin reject packet, err=%v", err)
return status.Errorf(codes.Internal, "internal error, packet rejected, contact the administrator")
}
Expand Down Expand Up @@ -258,3 +295,85 @@ func (s *Server) clientGracefulShutdown(c *client.Client) {
os.Exit(143)
}()
}

func (s *Server) loadConnectPlugins(ctx *user.Context, config plugin.Config) ([]pluginConfig, error) {
pluginsConfig := make([]pluginConfig, 0)
for _, p := range allPlugins {
p1, err := s.PluginService.FindOne(ctx, p.Name())
if err != nil {
log.Printf("failed retrieving plugin %q, err=%v", p.Name(), err)
return nil, status.Errorf(codes.Internal, "failed registering plugins")
}
if p1 == nil {
log.Printf("plugin not registered %q, skipping...", p.Name())
continue
}

if p.Name() == pluginsaudit.Name {
config.ParamsData["audit_storage_writer"] = s.SessionService.Storage.NewGenericStorageWriter()
}

for _, c := range p1.Connections {
if c.Name == config.ConnectionName {
cfg := c.Config
if len(ctx.User.Groups) > 0 && len(c.Groups) > 0 {
cfg = make([]string, 0)
for _, u := range ctx.User.Groups {
cfg = append(cfg, c.Groups[u]...)
}
}
cfg = removeDuplicates(cfg)
ep := pluginConfig{
Plugin: p,
config: cfg,
}

if err := p.OnStartup(config); err != nil {
log.Printf("failed starting plugin %q, err=%v", p.Name(), err)
return nil, status.Errorf(codes.Internal, "failed starting plugin")
}

if err = p.OnConnect(config); err != nil {
log.Printf("plugin %q refused to accept connection %q, err=%v", p1.Name, config.SessionId, err)
return nil, status.Errorf(codes.FailedPrecondition, err.Error())
}

pluginsConfig = append(pluginsConfig, ep)
break
}
}
}
return pluginsConfig, nil
}

func (s *Server) pluginOnDisconnect(config plugin.Config) error {
plugins := getPlugins(config.SessionId)
for _, p := range plugins {
return p.OnDisconnect(config)
}
return nil
}

func (s *Server) pluginOnReceive(sessionID string, pkt *pb.Packet) error {
plugins := getPlugins(sessionID)
for _, p := range plugins {
if err := p.OnReceive(sessionID, p.config, pkt); err != nil {
log.Printf("session=%v - plugin %q rejected packet, err=%v",
sessionID, p.Name(), err)
return err
}
}
return nil
}

func removeDuplicates(strSlice []string) []string {
allKeys := make(map[string]bool)
list := make([]string, 0)
for _, item := range strSlice {
if _, value := allKeys[item]; !value {
allKeys[item] = true
list = append(list, item)
}
}
return list
}
Loading

0 comments on commit 1883c87

Please sign in to comment.