diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..491704a --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "gomod" # See documentation for possible values + directory: "/core" # Location of package manifests + schedule: + interval: "daily" + - package-ecosystem: "npm" + directory: "/frontend" + schedule: + interval: "daily" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d8a1ca2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/.idea +/.vscode +/dev/elasticsearch \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..ac0e1d4 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +support@clidey.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..161d8e7 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,91 @@ +# Contributing + +When contributing to this repository, please first discuss the change you wish to make via issue, +email, or any other method with the owners of this repository before making a change. + +Please note we have a code of conduct, please follow it in all your interactions with the project. + +## Pull Request Process + +1. Please do not include any build artifacts unless absolutely necessary. +2. Update the README.md with details of changes to the interface, this includes new environment + variables, exposed ports, useful file locations and container parameters. +3. You're free to choose what you want to work on - there will always be existing issues and we welcome all feature requests. +4. If you choose to work on an issue, please check on the thread if there has been any activity - if someone's working on something, check with them and see if they'd like help, or if they're still working on it. +5. If you choose to add a new feature, please open an issue with a title that starts with [FR] and we can have a quick discussion to see viability! +6. When you're ready, open a pull request, add a detailed description of what you've done, and we'll be in touch! Do not hesitate to update it if you find you missed/wanted to add something. + +## Code of Conduct + +### Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of experience, +nationality, personal appearance, race, religion, or sexual identity and +orientation. + +### Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +### Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +### Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +### Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at support@clidey.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +### Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ diff --git a/core/graph/generated.go b/core/graph/generated.go index 04485aa..10adbb1 100644 --- a/core/graph/generated.go +++ b/core/graph/generated.go @@ -47,6 +47,12 @@ type DirectiveRoot struct { } type ComplexityRoot struct { + AIChatMessage struct { + Result func(childComplexity int) int + Text func(childComplexity int) int + Type func(childComplexity int) int + } + Column struct { Name func(childComplexity int) int Type func(childComplexity int) int @@ -71,6 +77,7 @@ type ComplexityRoot struct { Mutation struct { AddRow func(childComplexity int, typeArg model.DatabaseType, schema string, storageUnit string, values []*model.RecordInput) int AddStorageUnit func(childComplexity int, typeArg model.DatabaseType, schema string, storageUnit string, fields []*model.RecordInput) int + DeleteRow func(childComplexity int, typeArg model.DatabaseType, schema string, storageUnit string, values []*model.RecordInput) int Login func(childComplexity int, credentials model.LoginCredentials) int LoginWithProfile func(childComplexity int, profile model.LoginProfileInput) int Logout func(childComplexity int) int @@ -78,6 +85,8 @@ type ComplexityRoot struct { } Query struct { + AIChat func(childComplexity int, typeArg model.DatabaseType, schema string, input model.ChatInput) int + AIModel func(childComplexity int) int Database func(childComplexity int, typeArg model.DatabaseType) int Graph func(childComplexity int, typeArg model.DatabaseType, schema string) int Profiles func(childComplexity int) int @@ -115,6 +124,7 @@ type MutationResolver interface { AddStorageUnit(ctx context.Context, typeArg model.DatabaseType, schema string, storageUnit string, fields []*model.RecordInput) (*model.StatusResponse, error) UpdateStorageUnit(ctx context.Context, typeArg model.DatabaseType, schema string, storageUnit string, values []*model.RecordInput) (*model.StatusResponse, error) AddRow(ctx context.Context, typeArg model.DatabaseType, schema string, storageUnit string, values []*model.RecordInput) (*model.StatusResponse, error) + DeleteRow(ctx context.Context, typeArg model.DatabaseType, schema string, storageUnit string, values []*model.RecordInput) (*model.StatusResponse, error) } type QueryResolver interface { Profiles(ctx context.Context) ([]*model.LoginProfile, error) @@ -124,6 +134,8 @@ type QueryResolver interface { Row(ctx context.Context, typeArg model.DatabaseType, schema string, storageUnit string, where string, pageSize int, pageOffset int) (*model.RowsResult, error) RawExecute(ctx context.Context, typeArg model.DatabaseType, query string) (*model.RowsResult, error) Graph(ctx context.Context, typeArg model.DatabaseType, schema string) ([]*model.GraphUnit, error) + AIModel(ctx context.Context) ([]string, error) + AIChat(ctx context.Context, typeArg model.DatabaseType, schema string, input model.ChatInput) ([]*model.AIChatMessage, error) } type executableSchema struct { @@ -145,6 +157,27 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in _ = ec switch typeName + "." + field { + case "AIChatMessage.Result": + if e.complexity.AIChatMessage.Result == nil { + break + } + + return e.complexity.AIChatMessage.Result(childComplexity), true + + case "AIChatMessage.Text": + if e.complexity.AIChatMessage.Text == nil { + break + } + + return e.complexity.AIChatMessage.Text(childComplexity), true + + case "AIChatMessage.Type": + if e.complexity.AIChatMessage.Type == nil { + break + } + + return e.complexity.AIChatMessage.Type(childComplexity), true + case "Column.Name": if e.complexity.Column.Name == nil { break @@ -232,6 +265,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.AddStorageUnit(childComplexity, args["type"].(model.DatabaseType), args["schema"].(string), args["storageUnit"].(string), args["fields"].([]*model.RecordInput)), true + case "Mutation.DeleteRow": + if e.complexity.Mutation.DeleteRow == nil { + break + } + + args, err := ec.field_Mutation_DeleteRow_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.DeleteRow(childComplexity, args["type"].(model.DatabaseType), args["schema"].(string), args["storageUnit"].(string), args["values"].([]*model.RecordInput)), true + case "Mutation.Login": if e.complexity.Mutation.Login == nil { break @@ -275,6 +320,25 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.UpdateStorageUnit(childComplexity, args["type"].(model.DatabaseType), args["schema"].(string), args["storageUnit"].(string), args["values"].([]*model.RecordInput)), true + case "Query.AIChat": + if e.complexity.Query.AIChat == nil { + break + } + + args, err := ec.field_Query_AIChat_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.AIChat(childComplexity, args["type"].(model.DatabaseType), args["schema"].(string), args["input"].(model.ChatInput)), true + + case "Query.AIModel": + if e.complexity.Query.AIModel == nil { + break + } + + return e.complexity.Query.AIModel(childComplexity), true + case "Query.Database": if e.complexity.Query.Database == nil { break @@ -418,6 +482,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { rc := graphql.GetOperationContext(ctx) ec := executionContext{rc, e, 0, 0, make(chan graphql.DeferredResult)} inputUnmarshalMap := graphql.BuildUnmarshalerMap( + ec.unmarshalInputChatInput, ec.unmarshalInputLoginCredentials, ec.unmarshalInputLoginProfileInput, ec.unmarshalInputRecordInput, @@ -621,6 +686,48 @@ func (ec *executionContext) field_Mutation_AddStorageUnit_args(ctx context.Conte return args, nil } +func (ec *executionContext) field_Mutation_DeleteRow_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 model.DatabaseType + if tmp, ok := rawArgs["type"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("type")) + arg0, err = ec.unmarshalNDatabaseType2githubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐDatabaseType(ctx, tmp) + if err != nil { + return nil, err + } + } + args["type"] = arg0 + var arg1 string + if tmp, ok := rawArgs["schema"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("schema")) + arg1, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["schema"] = arg1 + var arg2 string + if tmp, ok := rawArgs["storageUnit"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("storageUnit")) + arg2, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["storageUnit"] = arg2 + var arg3 []*model.RecordInput + if tmp, ok := rawArgs["values"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("values")) + arg3, err = ec.unmarshalNRecordInput2ᚕᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐRecordInputᚄ(ctx, tmp) + if err != nil { + return nil, err + } + } + args["values"] = arg3 + return args, nil +} + func (ec *executionContext) field_Mutation_LoginWithProfile_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -693,6 +800,39 @@ func (ec *executionContext) field_Mutation_UpdateStorageUnit_args(ctx context.Co return args, nil } +func (ec *executionContext) field_Query_AIChat_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 model.DatabaseType + if tmp, ok := rawArgs["type"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("type")) + arg0, err = ec.unmarshalNDatabaseType2githubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐDatabaseType(ctx, tmp) + if err != nil { + return nil, err + } + } + args["type"] = arg0 + var arg1 string + if tmp, ok := rawArgs["schema"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("schema")) + arg1, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["schema"] = arg1 + var arg2 model.ChatInput + if tmp, ok := rawArgs["input"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) + arg2, err = ec.unmarshalNChatInput2githubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐChatInput(ctx, tmp) + if err != nil { + return nil, err + } + } + args["input"] = arg2 + return args, nil +} + func (ec *executionContext) field_Query_Database_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -908,6 +1048,143 @@ func (ec *executionContext) field___Type_fields_args(ctx context.Context, rawArg // region **************************** field.gotpl ***************************** +func (ec *executionContext) _AIChatMessage_Type(ctx context.Context, field graphql.CollectedField, obj *model.AIChatMessage) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_AIChatMessage_Type(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Type, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_AIChatMessage_Type(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "AIChatMessage", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _AIChatMessage_Result(ctx context.Context, field graphql.CollectedField, obj *model.AIChatMessage) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_AIChatMessage_Result(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Result, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*model.RowsResult) + fc.Result = res + return ec.marshalORowsResult2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐRowsResult(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_AIChatMessage_Result(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "AIChatMessage", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "Columns": + return ec.fieldContext_RowsResult_Columns(ctx, field) + case "Rows": + return ec.fieldContext_RowsResult_Rows(ctx, field) + case "DisableUpdate": + return ec.fieldContext_RowsResult_DisableUpdate(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type RowsResult", field.Name) + }, + } + return fc, nil +} + +func (ec *executionContext) _AIChatMessage_Text(ctx context.Context, field graphql.CollectedField, obj *model.AIChatMessage) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_AIChatMessage_Text(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Text, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_AIChatMessage_Text(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "AIChatMessage", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Column_Type(ctx context.Context, field graphql.CollectedField, obj *model.Column) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Column_Type(ctx, field) if err != nil { @@ -1656,6 +1933,65 @@ func (ec *executionContext) fieldContext_Mutation_AddRow(ctx context.Context, fi return fc, nil } +func (ec *executionContext) _Mutation_DeleteRow(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Mutation_DeleteRow(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Mutation().DeleteRow(rctx, fc.Args["type"].(model.DatabaseType), fc.Args["schema"].(string), fc.Args["storageUnit"].(string), fc.Args["values"].([]*model.RecordInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*model.StatusResponse) + fc.Result = res + return ec.marshalNStatusResponse2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐStatusResponse(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Mutation_DeleteRow(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Mutation", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "Status": + return ec.fieldContext_StatusResponse_Status(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type StatusResponse", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Mutation_DeleteRow_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } + return fc, nil +} + func (ec *executionContext) _Query_Profiles(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Query_Profiles(ctx, field) if err != nil { @@ -1846,10 +2182,134 @@ func (ec *executionContext) _Query_StorageUnit(ctx context.Context, field graphq } res := resTmp.([]*model.StorageUnit) fc.Result = res - return ec.marshalNStorageUnit2ᚕᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐStorageUnitᚄ(ctx, field.Selections, res) + return ec.marshalNStorageUnit2ᚕᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐStorageUnitᚄ(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Query_StorageUnit(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "Name": + return ec.fieldContext_StorageUnit_Name(ctx, field) + case "Attributes": + return ec.fieldContext_StorageUnit_Attributes(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type StorageUnit", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_StorageUnit_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } + return fc, nil +} + +func (ec *executionContext) _Query_Row(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_Row(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().Row(rctx, fc.Args["type"].(model.DatabaseType), fc.Args["schema"].(string), fc.Args["storageUnit"].(string), fc.Args["where"].(string), fc.Args["pageSize"].(int), fc.Args["pageOffset"].(int)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*model.RowsResult) + fc.Result = res + return ec.marshalNRowsResult2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐRowsResult(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Query_Row(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "Columns": + return ec.fieldContext_RowsResult_Columns(ctx, field) + case "Rows": + return ec.fieldContext_RowsResult_Rows(ctx, field) + case "DisableUpdate": + return ec.fieldContext_RowsResult_DisableUpdate(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type RowsResult", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_Row_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } + return fc, nil +} + +func (ec *executionContext) _Query_RawExecute(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_RawExecute(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().RawExecute(rctx, fc.Args["type"].(model.DatabaseType), fc.Args["query"].(string)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*model.RowsResult) + fc.Result = res + return ec.marshalNRowsResult2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐRowsResult(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_Query_StorageUnit(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_Query_RawExecute(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "Query", Field: field, @@ -1857,12 +2317,14 @@ func (ec *executionContext) fieldContext_Query_StorageUnit(ctx context.Context, IsResolver: true, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { switch field.Name { - case "Name": - return ec.fieldContext_StorageUnit_Name(ctx, field) - case "Attributes": - return ec.fieldContext_StorageUnit_Attributes(ctx, field) + case "Columns": + return ec.fieldContext_RowsResult_Columns(ctx, field) + case "Rows": + return ec.fieldContext_RowsResult_Rows(ctx, field) + case "DisableUpdate": + return ec.fieldContext_RowsResult_DisableUpdate(ctx, field) } - return nil, fmt.Errorf("no field named %q was found under type StorageUnit", field.Name) + return nil, fmt.Errorf("no field named %q was found under type RowsResult", field.Name) }, } defer func() { @@ -1872,15 +2334,15 @@ func (ec *executionContext) fieldContext_Query_StorageUnit(ctx context.Context, } }() ctx = graphql.WithFieldContext(ctx, fc) - if fc.Args, err = ec.field_Query_StorageUnit_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + if fc.Args, err = ec.field_Query_RawExecute_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { ec.Error(ctx, err) return fc, err } return fc, nil } -func (ec *executionContext) _Query_Row(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_Query_Row(ctx, field) +func (ec *executionContext) _Query_Graph(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_Graph(ctx, field) if err != nil { return graphql.Null } @@ -1893,7 +2355,7 @@ func (ec *executionContext) _Query_Row(ctx context.Context, field graphql.Collec }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Row(rctx, fc.Args["type"].(model.DatabaseType), fc.Args["schema"].(string), fc.Args["storageUnit"].(string), fc.Args["where"].(string), fc.Args["pageSize"].(int), fc.Args["pageOffset"].(int)) + return ec.resolvers.Query().Graph(rctx, fc.Args["type"].(model.DatabaseType), fc.Args["schema"].(string)) }) if err != nil { ec.Error(ctx, err) @@ -1905,12 +2367,12 @@ func (ec *executionContext) _Query_Row(ctx context.Context, field graphql.Collec } return graphql.Null } - res := resTmp.(*model.RowsResult) + res := resTmp.([]*model.GraphUnit) fc.Result = res - return ec.marshalNRowsResult2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐRowsResult(ctx, field.Selections, res) + return ec.marshalNGraphUnit2ᚕᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐGraphUnitᚄ(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_Query_Row(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_Query_Graph(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "Query", Field: field, @@ -1918,14 +2380,12 @@ func (ec *executionContext) fieldContext_Query_Row(ctx context.Context, field gr IsResolver: true, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { switch field.Name { - case "Columns": - return ec.fieldContext_RowsResult_Columns(ctx, field) - case "Rows": - return ec.fieldContext_RowsResult_Rows(ctx, field) - case "DisableUpdate": - return ec.fieldContext_RowsResult_DisableUpdate(ctx, field) + case "Unit": + return ec.fieldContext_GraphUnit_Unit(ctx, field) + case "Relations": + return ec.fieldContext_GraphUnit_Relations(ctx, field) } - return nil, fmt.Errorf("no field named %q was found under type RowsResult", field.Name) + return nil, fmt.Errorf("no field named %q was found under type GraphUnit", field.Name) }, } defer func() { @@ -1935,15 +2395,15 @@ func (ec *executionContext) fieldContext_Query_Row(ctx context.Context, field gr } }() ctx = graphql.WithFieldContext(ctx, fc) - if fc.Args, err = ec.field_Query_Row_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + if fc.Args, err = ec.field_Query_Graph_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { ec.Error(ctx, err) return fc, err } return fc, nil } -func (ec *executionContext) _Query_RawExecute(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_Query_RawExecute(ctx, field) +func (ec *executionContext) _Query_AIModel(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_AIModel(ctx, field) if err != nil { return graphql.Null } @@ -1956,7 +2416,7 @@ func (ec *executionContext) _Query_RawExecute(ctx context.Context, field graphql }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().RawExecute(rctx, fc.Args["type"].(model.DatabaseType), fc.Args["query"].(string)) + return ec.resolvers.Query().AIModel(rctx) }) if err != nil { ec.Error(ctx, err) @@ -1968,45 +2428,26 @@ func (ec *executionContext) _Query_RawExecute(ctx context.Context, field graphql } return graphql.Null } - res := resTmp.(*model.RowsResult) + res := resTmp.([]string) fc.Result = res - return ec.marshalNRowsResult2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐRowsResult(ctx, field.Selections, res) + return ec.marshalNString2ᚕstringᚄ(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_Query_RawExecute(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_Query_AIModel(_ context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "Query", Field: field, IsMethod: true, IsResolver: true, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - switch field.Name { - case "Columns": - return ec.fieldContext_RowsResult_Columns(ctx, field) - case "Rows": - return ec.fieldContext_RowsResult_Rows(ctx, field) - case "DisableUpdate": - return ec.fieldContext_RowsResult_DisableUpdate(ctx, field) - } - return nil, fmt.Errorf("no field named %q was found under type RowsResult", field.Name) + return nil, errors.New("field of type String does not have child fields") }, } - defer func() { - if r := recover(); r != nil { - err = ec.Recover(ctx, r) - ec.Error(ctx, err) - } - }() - ctx = graphql.WithFieldContext(ctx, fc) - if fc.Args, err = ec.field_Query_RawExecute_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { - ec.Error(ctx, err) - return fc, err - } return fc, nil } -func (ec *executionContext) _Query_Graph(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_Query_Graph(ctx, field) +func (ec *executionContext) _Query_AIChat(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_AIChat(ctx, field) if err != nil { return graphql.Null } @@ -2019,7 +2460,7 @@ func (ec *executionContext) _Query_Graph(ctx context.Context, field graphql.Coll }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Graph(rctx, fc.Args["type"].(model.DatabaseType), fc.Args["schema"].(string)) + return ec.resolvers.Query().AIChat(rctx, fc.Args["type"].(model.DatabaseType), fc.Args["schema"].(string), fc.Args["input"].(model.ChatInput)) }) if err != nil { ec.Error(ctx, err) @@ -2031,12 +2472,12 @@ func (ec *executionContext) _Query_Graph(ctx context.Context, field graphql.Coll } return graphql.Null } - res := resTmp.([]*model.GraphUnit) + res := resTmp.([]*model.AIChatMessage) fc.Result = res - return ec.marshalNGraphUnit2ᚕᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐGraphUnitᚄ(ctx, field.Selections, res) + return ec.marshalNAIChatMessage2ᚕᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐAIChatMessageᚄ(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_Query_Graph(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_Query_AIChat(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "Query", Field: field, @@ -2044,12 +2485,14 @@ func (ec *executionContext) fieldContext_Query_Graph(ctx context.Context, field IsResolver: true, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { switch field.Name { - case "Unit": - return ec.fieldContext_GraphUnit_Unit(ctx, field) - case "Relations": - return ec.fieldContext_GraphUnit_Relations(ctx, field) + case "Type": + return ec.fieldContext_AIChatMessage_Type(ctx, field) + case "Result": + return ec.fieldContext_AIChatMessage_Result(ctx, field) + case "Text": + return ec.fieldContext_AIChatMessage_Text(ctx, field) } - return nil, fmt.Errorf("no field named %q was found under type GraphUnit", field.Name) + return nil, fmt.Errorf("no field named %q was found under type AIChatMessage", field.Name) }, } defer func() { @@ -2059,7 +2502,7 @@ func (ec *executionContext) fieldContext_Query_Graph(ctx context.Context, field } }() ctx = graphql.WithFieldContext(ctx, fc) - if fc.Args, err = ec.field_Query_Graph_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + if fc.Args, err = ec.field_Query_AIChat_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { ec.Error(ctx, err) return fc, err } @@ -4332,6 +4775,47 @@ func (ec *executionContext) fieldContext___Type_specifiedByURL(_ context.Context // region **************************** input.gotpl ***************************** +func (ec *executionContext) unmarshalInputChatInput(ctx context.Context, obj interface{}) (model.ChatInput, error) { + var it model.ChatInput + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + fieldsInOrder := [...]string{"PreviousConversation", "Query", "Model"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } + switch k { + case "PreviousConversation": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("PreviousConversation")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.PreviousConversation = data + case "Query": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("Query")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.Query = data + case "Model": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("Model")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.Model = data + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputLoginCredentials(ctx context.Context, obj interface{}) (model.LoginCredentials, error) { var it model.LoginCredentials asMap := map[string]interface{}{} @@ -4491,6 +4975,52 @@ func (ec *executionContext) unmarshalInputRecordInput(ctx context.Context, obj i // region **************************** object.gotpl **************************** +var aIChatMessageImplementors = []string{"AIChatMessage"} + +func (ec *executionContext) _AIChatMessage(ctx context.Context, sel ast.SelectionSet, obj *model.AIChatMessage) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, aIChatMessageImplementors) + + out := graphql.NewFieldSet(fields) + deferred := make(map[string]*graphql.FieldSet) + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("AIChatMessage") + case "Type": + out.Values[i] = ec._AIChatMessage_Type(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "Result": + out.Values[i] = ec._AIChatMessage_Result(ctx, field, obj) + case "Text": + out.Values[i] = ec._AIChatMessage_Text(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch(ctx) + if out.Invalids > 0 { + return graphql.Null + } + + atomic.AddInt32(&ec.deferred, int32(len(deferred))) + + for label, dfs := range deferred { + ec.processDeferredGroup(graphql.DeferredGroup{ + Label: label, + Path: graphql.GetPath(ctx), + FieldSet: dfs, + Context: ctx, + }) + } + + return out +} + var columnImplementors = []string{"Column"} func (ec *executionContext) _Column(ctx context.Context, sel ast.SelectionSet, obj *model.Column) graphql.Marshaler { @@ -4730,6 +5260,13 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) if out.Values[i] == graphql.Null { out.Invalids++ } + case "DeleteRow": + out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) { + return ec._Mutation_DeleteRow(ctx, field) + }) + if out.Values[i] == graphql.Null { + out.Invalids++ + } default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -4925,6 +5462,50 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) + case "AIModel": + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_AIModel(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, + func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + } + + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) + case "AIChat": + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_AIChat(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, + func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) case "__type": out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) { @@ -5459,6 +6040,60 @@ func (ec *executionContext) ___Type(ctx context.Context, sel ast.SelectionSet, o // region ***************************** type.gotpl ***************************** +func (ec *executionContext) marshalNAIChatMessage2ᚕᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐAIChatMessageᚄ(ctx context.Context, sel ast.SelectionSet, v []*model.AIChatMessage) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalNAIChatMessage2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐAIChatMessage(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + for _, e := range ret { + if e == graphql.Null { + return graphql.Null + } + } + + return ret +} + +func (ec *executionContext) marshalNAIChatMessage2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐAIChatMessage(ctx context.Context, sel ast.SelectionSet, v *model.AIChatMessage) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "the requested element is null which the schema does not allow") + } + return graphql.Null + } + return ec._AIChatMessage(ctx, sel, v) +} + func (ec *executionContext) unmarshalNBoolean2bool(ctx context.Context, v interface{}) (bool, error) { res, err := graphql.UnmarshalBoolean(v) return res, graphql.ErrorOnPath(ctx, err) @@ -5474,6 +6109,11 @@ func (ec *executionContext) marshalNBoolean2bool(ctx context.Context, sel ast.Se return res } +func (ec *executionContext) unmarshalNChatInput2githubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐChatInput(ctx context.Context, v interface{}) (model.ChatInput, error) { + res, err := ec.unmarshalInputChatInput(ctx, v) + return res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) marshalNColumn2ᚕᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐColumnᚄ(ctx context.Context, sel ast.SelectionSet, v []*model.Column) graphql.Marshaler { ret := make(graphql.Array, len(v)) var wg sync.WaitGroup @@ -6271,6 +6911,13 @@ func (ec *executionContext) unmarshalORecordInput2ᚕᚖgithubᚗcomᚋclideyᚋ return res, nil } +func (ec *executionContext) marshalORowsResult2ᚖgithubᚗcomᚋclideyᚋwhodbᚋcoreᚋgraphᚋmodelᚐRowsResult(ctx context.Context, sel ast.SelectionSet, v *model.RowsResult) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._RowsResult(ctx, sel, v) +} + func (ec *executionContext) unmarshalOString2ᚖstring(ctx context.Context, v interface{}) (*string, error) { if v == nil { return nil, nil diff --git a/core/graph/model/models_gen.go b/core/graph/model/models_gen.go index ad3aeeb..8d9422b 100644 --- a/core/graph/model/models_gen.go +++ b/core/graph/model/models_gen.go @@ -8,6 +8,18 @@ import ( "strconv" ) +type AIChatMessage struct { + Type string `json:"Type"` + Result *RowsResult `json:"Result,omitempty"` + Text string `json:"Text"` +} + +type ChatInput struct { + PreviousConversation string `json:"PreviousConversation"` + Query string `json:"Query"` + Model string `json:"Model"` +} + type Column struct { Type string `json:"Type"` Name string `json:"Name"` diff --git a/core/graph/schema.graphqls b/core/graph/schema.graphqls index 07dfc12..c84ee16 100644 --- a/core/graph/schema.graphqls +++ b/core/graph/schema.graphqls @@ -83,6 +83,17 @@ type StatusResponse { Status: Boolean! } +input ChatInput { + PreviousConversation: String! + Query: String! + Model: String! +} + +type AIChatMessage { + Type: String! + Result: RowsResult + Text: String! +} type Query { Profiles: [LoginProfile!]! @@ -92,6 +103,8 @@ type Query { Row(type: DatabaseType!, schema: String!, storageUnit: String!, where: String!, pageSize: Int!, pageOffset: Int!): RowsResult! # row, document RawExecute(type: DatabaseType!, query: String!): RowsResult! Graph(type: DatabaseType!, schema: String!): [GraphUnit!]! + AIModel: [String!]! + AIChat(type: DatabaseType!, schema: String!, input: ChatInput!): [AIChatMessage!]! } type Mutation { @@ -102,4 +115,5 @@ type Mutation { AddStorageUnit(type: DatabaseType!, schema: String!, storageUnit: String!, fields: [RecordInput!]!): StatusResponse! UpdateStorageUnit(type: DatabaseType!, schema: String!, storageUnit: String!, values: [RecordInput!]!): StatusResponse! AddRow(type: DatabaseType!, schema: String!, storageUnit: String!, values: [RecordInput!]!): StatusResponse! + DeleteRow(type: DatabaseType!, schema: String!, storageUnit: String!, values: [RecordInput!]!): StatusResponse! } \ No newline at end of file diff --git a/core/graph/schema.resolvers.go b/core/graph/schema.resolvers.go index 79cf3f4..3c6ddac 100644 --- a/core/graph/schema.resolvers.go +++ b/core/graph/schema.resolvers.go @@ -12,6 +12,7 @@ import ( "github.com/clidey/whodb/core/src" "github.com/clidey/whodb/core/src/auth" "github.com/clidey/whodb/core/src/engine" + "github.com/clidey/whodb/core/src/llm" ) // Login is the resolver for the Login field. @@ -113,6 +114,23 @@ func (r *mutationResolver) AddRow(ctx context.Context, typeArg model.DatabaseTyp }) } status, err := src.MainEngine.Choose(engine.DatabaseType(typeArg)).AddRow(config, schema, storageUnit, valuesRecords) + + if err != nil { + return nil, err + } + return &model.StatusResponse{ + Status: status, + }, nil +} + +// DeleteRow is the resolver for the DeleteRow field. +func (r *mutationResolver) DeleteRow(ctx context.Context, typeArg model.DatabaseType, schema string, storageUnit string, values []*model.RecordInput) (*model.StatusResponse, error) { + config := engine.NewPluginConfig(auth.GetCredentials(ctx)) + valuesMap := map[string]string{} + for _, value := range values { + valuesMap[value.Key] = value.Value + } + status, err := src.MainEngine.Choose(engine.DatabaseType(typeArg)).DeleteRow(config, schema, storageUnit, valuesMap) if err != nil { return nil, err } @@ -225,6 +243,51 @@ func (r *queryResolver) Graph(ctx context.Context, typeArg model.DatabaseType, s return graphUnitsModel, nil } +// AIModel is the resolver for the AIModel field. +func (r *queryResolver) AIModel(ctx context.Context) ([]string, error) { + models, err := llm.Instance(llm.Ollama_LLMType).GetSupportedModels() + if err != nil { + return nil, err + } + return models, nil +} + +// AIChat is the resolver for the AIChat field. +func (r *queryResolver) AIChat(ctx context.Context, typeArg model.DatabaseType, schema string, input model.ChatInput) ([]*model.AIChatMessage, error) { + config := engine.NewPluginConfig(auth.GetCredentials(ctx)) + messages, err := src.MainEngine.Choose(engine.DatabaseType(typeArg)).Chat(config, schema, input.Model, input.PreviousConversation, input.Query) + + if err != nil { + return nil, err + } + + chatResponse := []*model.AIChatMessage{} + + for _, message := range messages { + var result *model.RowsResult + if message.Type == "sql" { + columns := []*model.Column{} + for _, column := range message.Result.Columns { + columns = append(columns, &model.Column{ + Type: column.Type, + Name: column.Name, + }) + } + result = &model.RowsResult{ + Columns: columns, + Rows: message.Result.Rows, + } + } + chatResponse = append(chatResponse, &model.AIChatMessage{ + Type: message.Type, + Result: result, + Text: message.Text, + }) + } + + return chatResponse, nil +} + // Mutation returns MutationResolver implementation. func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} } diff --git a/core/src/common/chat.go b/core/src/common/chat.go new file mode 100644 index 0000000..eb91d82 --- /dev/null +++ b/core/src/common/chat.go @@ -0,0 +1,19 @@ +package common + +const RawSQLQueryPrompt = `You are a %v SQL query expert. You have access to the following information: +Schema: %v +Tables and Fields: +%v +Instructions: +Based on the user's input, generate a explanation response with a valid SQL query that will retrieve the required data or execute an action from the database. + +Previous Conversation: +%v + +User Prompt: +%v + +System Prompt: +Generate the SQL query inside ` + "```sql" + ` that corresponds to the user's request. Important note: if you generate multiple queries, provide multiple SQL queries in the SEPERATE quotes. +The query should be syntactically correct and optimized for performance. Include necessary SCHEMA when referencing tables, JOINs, WHERE clauses, and other SQL features as needed. +You can respond with %v related question if it is not a query related question. Speak to the user as "you".` diff --git a/core/src/common/utils.go b/core/src/common/utils.go index 77d2231..8867622 100644 --- a/core/src/common/utils.go +++ b/core/src/common/utils.go @@ -1,6 +1,12 @@ package common -import "github.com/clidey/whodb/core/src/engine" +import ( + "fmt" + "regexp" + "strings" + + "github.com/clidey/whodb/core/src/engine" +) func ContainsString(slice []string, element string) bool { for _, item := range slice { @@ -19,3 +25,53 @@ func GetRecordValueOrDefault(records []engine.Record, key string, defaultValue s } return defaultValue } + +type ExtractedText struct { + Type string + Text string +} + +func ExtractCodeFromResponse(response string) []ExtractedText { + tripleBacktickPattern := regexp.MustCompile("(?s)```(sql)?(.*?)```") + + codeBlocks := tripleBacktickPattern.FindAllStringSubmatchIndex(response, -1) + + var result []ExtractedText + var lastIndex int + + for _, loc := range codeBlocks { + start, end := loc[0], loc[1] + codeTypeStart, codeTypeEnd, contentStart, contentEnd := loc[2], loc[3], loc[4], loc[5] + + codeContent := response[contentStart:contentEnd] + + codeType := "sql" + if codeTypeStart != -1 && codeTypeEnd != -1 { + codeType = response[codeTypeStart:codeTypeEnd] + } + + if start > lastIndex { + result = append(result, ExtractedText{Type: "message", Text: response[lastIndex:start]}) + } + + result = append(result, ExtractedText{Type: codeType, Text: codeContent}) + + lastIndex = end + } + + if lastIndex < len(response) { + result = append(result, ExtractedText{Type: "message", Text: response[lastIndex:]}) + } + + return result +} + +func JoinWithQuotes(arr []string) string { + quotedStrings := make([]string, len(arr)) + + for i, str := range arr { + quotedStrings[i] = fmt.Sprintf("\"%s\"", str) + } + + return strings.Join(quotedStrings, ", ") +} diff --git a/core/src/engine/plugin.go b/core/src/engine/plugin.go index f96cc9e..5af4379 100644 --- a/core/src/engine/plugin.go +++ b/core/src/engine/plugin.go @@ -55,6 +55,12 @@ type GraphUnit struct { Relations []GraphUnitRelationship } +type ChatMessage struct { + Type string + Result *GetRowsResult + Text string +} + type PluginFunctions interface { GetDatabases(config *PluginConfig) ([]string, error) IsAvailable(config *PluginConfig) bool @@ -63,9 +69,11 @@ type PluginFunctions interface { AddStorageUnit(config *PluginConfig, schema string, storageUnit string, fields map[string]string) (bool, error) UpdateStorageUnit(config *PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) AddRow(config *PluginConfig, schema string, storageUnit string, values []Record) (bool, error) + DeleteRow(config *PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) GetRows(config *PluginConfig, schema string, storageUnit string, where string, pageSize int, pageOffset int) (*GetRowsResult, error) GetGraph(config *PluginConfig, schema string) ([]GraphUnit, error) RawExecute(config *PluginConfig, query string) (*GetRowsResult, error) + Chat(config *PluginConfig, schema string, model string, previousConversation string, query string) ([]*ChatMessage, error) } type Plugin struct { diff --git a/core/src/llm/llm.go b/core/src/llm/llm.go new file mode 100644 index 0000000..cf8cfb7 --- /dev/null +++ b/core/src/llm/llm.go @@ -0,0 +1,145 @@ +package llm + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" +) + +type completionRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type completionResponse struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Response string `json:"response"` + Done bool `json:"done"` +} + +const ollamaLocalEndpoint = "http://localhost:11434/api" + +type LLMType string + +const ( + Ollama_LLMType LLMType = "Ollama" +) + +type LLMModel string + +const ( + Llama3_LLMModel LLMModel = "Llama3" +) + +type LLMClient struct { + Type LLMType +} + +func (c *LLMClient) Complete(prompt string, model LLMModel, receiverChan *chan string) (*string, error) { + requestBody, err := json.Marshal(completionRequest{ + Model: string(model), + Prompt: prompt, + }) + + if err != nil { + return nil, err + } + + var url string + switch c.Type { + case Ollama_LLMType: + url = fmt.Sprintf("%v/generate", ollamaLocalEndpoint) + } + + resp, err := http.Post(url, "application/json", bytes.NewBuffer(requestBody)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, errors.New(string(body)) + } + + responseBuilder := strings.Builder{} + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + var completionResponse completionResponse + err := json.Unmarshal([]byte(line), &completionResponse) + if err != nil { + return nil, err + } + if receiverChan != nil { + *receiverChan <- completionResponse.Response + } + if _, err := responseBuilder.WriteString(completionResponse.Response); err != nil { + return nil, err + } + if completionResponse.Done { + response := responseBuilder.String() + return &response, nil + } + } + + return nil, scanner.Err() +} + +func (c *LLMClient) GetSupportedModels() ([]string, error) { + var url string + switch c.Type { + case Ollama_LLMType: + url = fmt.Sprintf("%v/tags", ollamaLocalEndpoint) + } + + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, errors.New(string(body)) + } + + var modelsResp struct { + Models []struct { + Name string `json:"model"` + } `json:"models"` + } + if err := json.NewDecoder(resp.Body).Decode(&modelsResp); err != nil { + return nil, err + } + + models := []string{} + for _, model := range modelsResp.Models { + models = append(models, model.Name) + } + + return models, nil +} + +var llmInstance map[LLMType]*LLMClient + +func Instance(llmType LLMType) *LLMClient { + if llmInstance == nil { + llmInstance = make(map[LLMType]*LLMClient) + } + + if _, ok := llmInstance[llmType]; ok { + return llmInstance[llmType] + } + instance := &LLMClient{ + Type: llmType, + } + llmInstance[llmType] = instance + return instance +} diff --git a/core/src/plugins/elasticsearch/delete.go b/core/src/plugins/elasticsearch/delete.go new file mode 100644 index 0000000..00d4b35 --- /dev/null +++ b/core/src/plugins/elasticsearch/delete.go @@ -0,0 +1,65 @@ +package elasticsearch + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/clidey/whodb/core/src/engine" +) + +func (p *ElasticSearchPlugin) DeleteRow(config *engine.PluginConfig, database string, storageUnit string, values map[string]string) (bool, error) { + client, err := DB(config) + if err != nil { + return false, err + } + + // Extract the document JSON + documentJSON, ok := values["document"] + if !ok { + return false, errors.New("missing 'document' key in values map") + } + + // Unmarshal the JSON to extract the _id field + var jsonValues map[string]interface{} + if err := json.Unmarshal([]byte(documentJSON), &jsonValues); err != nil { + return false, err + } + + // Get the _id from the document + id, ok := jsonValues["_id"] + if !ok { + return false, errors.New("missing '_id' field in the document") + } + + // Delete the document by ID + res, err := client.Delete( + storageUnit, + id.(string), + client.Delete.WithContext(context.Background()), + client.Delete.WithRefresh("true"), // Ensure the deletion is immediately visible + ) + if err != nil { + return false, fmt.Errorf("failed to execute delete: %w", err) + } + defer res.Body.Close() + + // Check if the response indicates an error + if res.IsError() { + return false, fmt.Errorf("error deleting document: %s", res.String()) + } + + // Decode the response to check the result + var deleteResponse map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&deleteResponse); err != nil { + return false, err + } + + // Check if the deletion was successful + if result, ok := deleteResponse["result"].(string); ok && result != "deleted" { + return false, errors.New("no documents were deleted") + } + + return true, nil +} diff --git a/core/src/plugins/elasticsearch/elasticsearch.go b/core/src/plugins/elasticsearch/elasticsearch.go index 593cef4..c1ce758 100644 --- a/core/src/plugins/elasticsearch/elasticsearch.go +++ b/core/src/plugins/elasticsearch/elasticsearch.go @@ -149,6 +149,10 @@ func (p *ElasticSearchPlugin) RawExecute(config *engine.PluginConfig, query stri return nil, errors.New("unsupported operation") } +func (p *ElasticSearchPlugin) Chat(config *engine.PluginConfig, schema string, model string, previousConversation string, query string) ([]*engine.ChatMessage, error) { + return nil, errors.ErrUnsupported +} + func NewElasticSearchPlugin() *engine.Plugin { return &engine.Plugin{ Type: engine.DatabaseType_ElasticSearch, diff --git a/core/src/plugins/mongodb/delete.go b/core/src/plugins/mongodb/delete.go new file mode 100644 index 0000000..b1c7d05 --- /dev/null +++ b/core/src/plugins/mongodb/delete.go @@ -0,0 +1,57 @@ +package mongodb + +import ( + "context" + "encoding/json" + "errors" + "github.com/clidey/whodb/core/src/engine" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +func (p *MongoDBPlugin) DeleteRow(config *engine.PluginConfig, database string, storageUnit string, values map[string]string) (bool, error) { + ctx := context.Background() + client, err := DB(config) + if err != nil { + return false, err + } + defer client.Disconnect(ctx) + + db := client.Database(database) + collection := db.Collection(storageUnit) + + documentJSON, ok := values["document"] + if !ok { + return false, errors.New("missing 'document' key in values map") + } + + var jsonValues bson.M + if err := json.Unmarshal([]byte(documentJSON), &jsonValues); err != nil { + return false, err + } + + id, ok := jsonValues["_id"] + if !ok { + return false, errors.New("missing '_id' field in the document") + } + + objectID, err := primitive.ObjectIDFromHex(id.(string)) + if err != nil { + return false, errors.New("invalid '_id' field; not a valid ObjectID") + } + + delete(jsonValues, "_id") + + filter := bson.M{"_id": objectID} + + result, err := collection.DeleteOne(ctx, filter) + if err != nil { + return false, err + } + + if result.DeletedCount == 0 { + return false, errors.New("no documents were deleted") + } + + return true, nil +} diff --git a/core/src/plugins/mongodb/mongodb.go b/core/src/plugins/mongodb/mongodb.go index cc54063..fd34b46 100644 --- a/core/src/plugins/mongodb/mongodb.go +++ b/core/src/plugins/mongodb/mongodb.go @@ -131,6 +131,10 @@ func (p *MongoDBPlugin) RawExecute(config *engine.PluginConfig, query string) (* return nil, errors.ErrUnsupported } +func (p *MongoDBPlugin) Chat(config *engine.PluginConfig, schema string, model string, previousConversation string, query string) ([]*engine.ChatMessage, error) { + return nil, errors.ErrUnsupported +} + func NewMongoDBPlugin() *engine.Plugin { return &engine.Plugin{ Type: engine.DatabaseType_MongoDB, diff --git a/core/src/plugins/mysql/chat.go b/core/src/plugins/mysql/chat.go new file mode 100644 index 0000000..c520257 --- /dev/null +++ b/core/src/plugins/mysql/chat.go @@ -0,0 +1,66 @@ +package mysql + +import ( + "fmt" + "strings" + + "github.com/clidey/whodb/core/src/common" + "github.com/clidey/whodb/core/src/engine" + "github.com/clidey/whodb/core/src/llm" +) + +func (p *MySQLPlugin) Chat(config *engine.PluginConfig, schema string, model string, previousConversation string, query string) ([]*engine.ChatMessage, error) { + db, err := DB(config) + if err != nil { + return nil, err + } + sqlDb, err := db.DB() + if err != nil { + return nil, err + } + defer sqlDb.Close() + + tableFields, err := getTableSchema(db, schema) + if err != nil { + return nil, err + } + + tableDetails := strings.Builder{} + for tableName, fields := range tableFields { + tableDetails.WriteString(fmt.Sprintf("table: %v\n", tableName)) + for _, field := range fields { + tableDetails.WriteString(fmt.Sprintf("- %v (%v)\n", field.Key, field.Value)) + } + } + + context := tableDetails.String() + + completeQuery := fmt.Sprintf(common.RawSQLQueryPrompt, "MySQL", schema, context, previousConversation, query, "MySQL") + + response, err := llm.Instance(llm.Ollama_LLMType).Complete(completeQuery, llm.LLMModel(model), nil) + if err != nil { + return nil, err + } + + chats := common.ExtractCodeFromResponse(*response) + chatMessages := []*engine.ChatMessage{} + for _, chat := range chats { + var result *engine.GetRowsResult + chatType := "message" + if chat.Type == "sql" { + rowResult, err := p.RawExecute(config, chat.Text) + if err != nil { + return nil, err + } + chatType = "sql" + result = rowResult + } + chatMessages = append(chatMessages, &engine.ChatMessage{ + Type: chatType, + Result: result, + Text: chat.Text, + }) + } + + return chatMessages, nil +} diff --git a/core/src/plugins/mysql/delete.go b/core/src/plugins/mysql/delete.go new file mode 100644 index 0000000..f4e0e41 --- /dev/null +++ b/core/src/plugins/mysql/delete.go @@ -0,0 +1,69 @@ +package mysql + +import ( + "errors" + "fmt" + + "github.com/clidey/whodb/core/src/common" + "github.com/clidey/whodb/core/src/engine" +) + +func (p *MySQLPlugin) DeleteRow(config *engine.PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) { + db, err := DB(config) + if err != nil { + return false, err + } + + sqlDb, err := db.DB() + if err != nil { + return false, err + } + defer sqlDb.Close() + + pkColumns, err := getPrimaryKeyColumns(db, schema, storageUnit) + if err != nil { + return false, err + } + + columnTypes, err := getColumnTypes(db, schema, storageUnit) + if err != nil { + return false, err + } + + conditions := make(map[string]interface{}) + convertedValues := make(map[string]interface{}) + for column, strValue := range values { + columnType, exists := columnTypes[column] + if !exists { + return false, fmt.Errorf("column '%s' does not exist in table %s", column, storageUnit) + } + + convertedValue, err := convertStringValue(strValue, columnType) + if err != nil { + return false, fmt.Errorf("failed to convert value for column '%s': %v", column, err) + } + + if common.ContainsString(pkColumns, column) { + conditions[column] = convertedValue + } else { + convertedValues[column] = convertedValue + } + } + + tableName := fmt.Sprintf("%s.%s", schema, storageUnit) + dbConditions := db.Table(tableName) + for key, value := range conditions { + dbConditions = dbConditions.Where(fmt.Sprintf("%s = ?", key), value) + } + + result := dbConditions.Table(tableName).Delete(convertedValues) + if result.Error != nil { + return false, result.Error + } + + if result.RowsAffected == 0 { + return false, errors.New("no rows were deleted") + } + + return true, nil +} diff --git a/core/src/plugins/mysql/mysql.go b/core/src/plugins/mysql/mysql.go index 24c4c46..5cc8263 100644 --- a/core/src/plugins/mysql/mysql.go +++ b/core/src/plugins/mysql/mysql.go @@ -133,9 +133,9 @@ func (p *MySQLPlugin) GetStorageUnits(config *engine.PluginConfig, schema string func getTableSchema(db *gorm.DB, schema string) (map[string][]engine.Record, error) { var result []struct { - TableName string `gorm:"column:table_name"` - ColumnName string `gorm:"column:column_name"` - DataType string `gorm:"column:data_type"` + TableName string `gorm:"column:TABLE_NAME"` + ColumnName string `gorm:"column:COLUMN_NAME"` + DataType string `gorm:"column:DATA_TYPE"` } query := fmt.Sprintf(` @@ -239,6 +239,7 @@ func NewMySQLPlugin() *engine.Plugin { PluginFunctions: &MySQLPlugin{}, } } + func NewMyMariaDBPlugin() *engine.Plugin { return &engine.Plugin{ Type: engine.DatabaseType_MariaDB, diff --git a/core/src/plugins/mysql/update.go b/core/src/plugins/mysql/update.go index 44721e8..52369b3 100644 --- a/core/src/plugins/mysql/update.go +++ b/core/src/plugins/mysql/update.go @@ -3,11 +3,9 @@ package mysql import ( "errors" "fmt" - "strconv" "github.com/clidey/whodb/core/src/common" "github.com/clidey/whodb/core/src/engine" - "gorm.io/gorm" ) func (p *MySQLPlugin) UpdateStorageUnit(config *engine.PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) { @@ -55,7 +53,7 @@ func (p *MySQLPlugin) UpdateStorageUnit(config *engine.PluginConfig, schema stri tableName := fmt.Sprintf("%s.%s", schema, storageUnit) dbConditions := db.Table(tableName) for key, value := range conditions { - dbConditions = dbConditions.Where(fmt.Sprintf("%s = ?", key), value) + dbConditions = dbConditions.Where(fmt.Sprintf("\"%s\" = ?", key), value) } result := dbConditions.Table(tableName).Updates(convertedValues) @@ -69,80 +67,3 @@ func (p *MySQLPlugin) UpdateStorageUnit(config *engine.PluginConfig, schema stri return true, nil } - -func getPrimaryKeyColumns(db *gorm.DB, schema string, tableName string) ([]string, error) { - var primaryKeys []string - query := ` - SELECT k.column_name - FROM information_schema.table_constraints t - JOIN information_schema.key_column_usage k - USING (constraint_name, table_schema, table_name) - WHERE t.constraint_type = 'PRIMARY KEY' - AND t.table_schema = ? - AND t.table_name = ?; - ` - rows, err := db.Raw(query, schema, tableName).Rows() - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var pkColumn string - if err := rows.Scan(&pkColumn); err != nil { - return nil, err - } - primaryKeys = append(primaryKeys, pkColumn) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - if len(primaryKeys) == 0 { - return nil, fmt.Errorf("no primary key found for table %s", tableName) - } - - return primaryKeys, nil -} - -func getColumnTypes(db *gorm.DB, schema, tableName string) (map[string]string, error) { - columnTypes := make(map[string]string) - query := ` - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_schema = ? AND table_name = ?; - ` - rows, err := db.Raw(query, schema, tableName).Rows() - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var columnName, dataType string - if err := rows.Scan(&columnName, &dataType); err != nil { - return nil, err - } - columnTypes[columnName] = dataType - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return columnTypes, nil -} - -func convertStringValue(value, columnType string) (interface{}, error) { - switch columnType { - case "int", "bigint", "smallint", "tinyint", "mediumint": - return strconv.Atoi(value) - case "boolean", "bit": - return strconv.ParseBool(value) - case "float", "double", "decimal": - return strconv.ParseFloat(value, 64) - default: - return value, nil - } -} diff --git a/core/src/plugins/mysql/utils.go b/core/src/plugins/mysql/utils.go new file mode 100644 index 0000000..7d7dd06 --- /dev/null +++ b/core/src/plugins/mysql/utils.go @@ -0,0 +1,85 @@ +package mysql + +import ( + "fmt" + "strconv" + + "gorm.io/gorm" +) + +func getPrimaryKeyColumns(db *gorm.DB, schema string, tableName string) ([]string, error) { + var primaryKeys []string + query := ` + SELECT k.column_name + FROM information_schema.table_constraints t + JOIN information_schema.key_column_usage k + USING (constraint_name, table_schema, table_name) + WHERE t.constraint_type = 'PRIMARY KEY' + AND t.table_schema = ? + AND t.table_name = ?; + ` + rows, err := db.Raw(query, schema, tableName).Rows() + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var pkColumn string + if err := rows.Scan(&pkColumn); err != nil { + return nil, err + } + primaryKeys = append(primaryKeys, pkColumn) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + if len(primaryKeys) == 0 { + return nil, fmt.Errorf("no primary key found for table %s", tableName) + } + + return primaryKeys, nil +} + +func getColumnTypes(db *gorm.DB, schema, tableName string) (map[string]string, error) { + columnTypes := make(map[string]string) + query := ` + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_schema = ? AND table_name = ?; + ` + rows, err := db.Raw(query, schema, tableName).Rows() + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var columnName, dataType string + if err := rows.Scan(&columnName, &dataType); err != nil { + return nil, err + } + columnTypes[columnName] = dataType + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return columnTypes, nil +} + +func convertStringValue(value, columnType string) (interface{}, error) { + switch columnType { + case "int", "bigint", "smallint", "tinyint", "mediumint": + return strconv.Atoi(value) + case "boolean", "bit": + return strconv.ParseBool(value) + case "float", "double", "decimal": + return strconv.ParseFloat(value, 64) + default: + return value, nil + } +} diff --git a/core/src/plugins/postgres/chat.go b/core/src/plugins/postgres/chat.go new file mode 100644 index 0000000..81c8086 --- /dev/null +++ b/core/src/plugins/postgres/chat.go @@ -0,0 +1,66 @@ +package postgres + +import ( + "fmt" + "strings" + + "github.com/clidey/whodb/core/src/common" + "github.com/clidey/whodb/core/src/engine" + "github.com/clidey/whodb/core/src/llm" +) + +func (p *PostgresPlugin) Chat(config *engine.PluginConfig, schema string, model string, previousConversation string, query string) ([]*engine.ChatMessage, error) { + db, err := DB(config) + if err != nil { + return nil, err + } + sqlDb, err := db.DB() + if err != nil { + return nil, err + } + defer sqlDb.Close() + + tableFields, err := getTableSchema(db, schema) + if err != nil { + return nil, err + } + + tableDetails := strings.Builder{} + for tableName, fields := range tableFields { + tableDetails.WriteString(fmt.Sprintf("table: %v\n", tableName)) + for _, field := range fields { + tableDetails.WriteString(fmt.Sprintf("- %v (%v)\n", field.Key, field.Value)) + } + } + + context := tableDetails.String() + + completeQuery := fmt.Sprintf(common.RawSQLQueryPrompt, "Postgres", schema, context, previousConversation, query, "Postgres") + + response, err := llm.Instance(llm.Ollama_LLMType).Complete(completeQuery, llm.LLMModel(model), nil) + if err != nil { + return nil, err + } + + chats := common.ExtractCodeFromResponse(*response) + chatMessages := []*engine.ChatMessage{} + for _, chat := range chats { + var result *engine.GetRowsResult + chatType := "message" + if chat.Type == "sql" { + rowResult, err := p.RawExecute(config, chat.Text) + if err != nil { + return nil, err + } + chatType = "sql" + result = rowResult + } + chatMessages = append(chatMessages, &engine.ChatMessage{ + Type: chatType, + Result: result, + Text: chat.Text, + }) + } + + return chatMessages, nil +} diff --git a/core/src/plugins/postgres/delete.go b/core/src/plugins/postgres/delete.go new file mode 100644 index 0000000..b2e6668 --- /dev/null +++ b/core/src/plugins/postgres/delete.go @@ -0,0 +1,68 @@ +package postgres + +import ( + "errors" + "fmt" + + "github.com/clidey/whodb/core/src/common" + "github.com/clidey/whodb/core/src/engine" +) + +func (p *PostgresPlugin) DeleteRow(config *engine.PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) { + db, err := DB(config) + if err != nil { + return false, err + } + + sqlDb, err := db.DB() + if err != nil { + return false, err + } + defer sqlDb.Close() + + pkColumns, err := getPrimaryKeyColumns(db, schema, storageUnit) + if err != nil { + return false, err + } + + columnTypes, err := getColumnTypes(db, schema, storageUnit) + if err != nil { + return false, err + } + + conditions := make(map[string]interface{}) + convertedValues := make(map[string]interface{}) + for column, strValue := range values { + columnType, exists := columnTypes[column] + if !exists { + return false, fmt.Errorf("column '%s' does not exist in table %s", column, storageUnit) + } + + convertedValue, err := convertStringValue(strValue, columnType) + if err != nil { + return false, fmt.Errorf("failed to convert value for column '%s': %v", column, err) + } + + if common.ContainsString(pkColumns, column) { + conditions[column] = convertedValue + } else { + convertedValues[column] = convertedValue + } + } + + tableName := fmt.Sprintf("%s.%s", schema, storageUnit) + dbConditions := db.Table(tableName) + for key, value := range conditions { + dbConditions = dbConditions.Where(fmt.Sprintf("\"%s\" = ?", key), value) + } + result := dbConditions.Table(tableName).Delete(convertedValues) + if result.Error != nil { + return false, result.Error + } + + if result.RowsAffected == 0 { + return false, errors.New("no rows were deleted") + } + + return true, nil +} diff --git a/core/src/plugins/postgres/postgres.go b/core/src/plugins/postgres/postgres.go index 4d6aa4f..5ec8d7e 100644 --- a/core/src/plugins/postgres/postgres.go +++ b/core/src/plugins/postgres/postgres.go @@ -5,6 +5,7 @@ import ( "fmt" "log" + "github.com/clidey/whodb/core/src/common" "github.com/clidey/whodb/core/src/engine" "gorm.io/gorm" ) @@ -162,7 +163,16 @@ func getTableSchema(db *gorm.DB, schema string) (map[string][]engine.Record, err } func (p *PostgresPlugin) GetRows(config *engine.PluginConfig, schema string, storageUnit string, where string, pageSize int, pageOffset int) (*engine.GetRowsResult, error) { - query := fmt.Sprintf("SELECT * FROM \"%v\".\"%s\"", schema, storageUnit) + db, err := DB(config) + if err != nil { + return nil, err + } + sortKeyRes, err := getPrimaryKeyColumns(db, schema, storageUnit) + if err != nil { + return nil, err + } + quotedKeys := common.JoinWithQuotes(sortKeyRes) + query := fmt.Sprintf("SELECT * FROM \"%v\".\"%s\" order by %v asc", schema, storageUnit, quotedKeys) if len(where) > 0 { query = fmt.Sprintf("%v WHERE %v", query, where) } diff --git a/core/src/plugins/postgres/update.go b/core/src/plugins/postgres/update.go index ce0a3c0..f13ed3e 100644 --- a/core/src/plugins/postgres/update.go +++ b/core/src/plugins/postgres/update.go @@ -3,13 +3,9 @@ package postgres import ( "errors" "fmt" - "strconv" - "time" "github.com/clidey/whodb/core/src/common" "github.com/clidey/whodb/core/src/engine" - "github.com/google/uuid" - "gorm.io/gorm" ) func (p *PostgresPlugin) UpdateStorageUnit(config *engine.PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) { @@ -57,7 +53,7 @@ func (p *PostgresPlugin) UpdateStorageUnit(config *engine.PluginConfig, schema s tableName := fmt.Sprintf("%s.%s", schema, storageUnit) dbConditions := db.Table(tableName) for key, value := range conditions { - dbConditions = dbConditions.Where(fmt.Sprintf("%s = ?", key), value) + dbConditions = dbConditions.Where(fmt.Sprintf("\"%s\" = ?", key), value) } result := dbConditions.Table(tableName).Updates(convertedValues) @@ -71,97 +67,3 @@ func (p *PostgresPlugin) UpdateStorageUnit(config *engine.PluginConfig, schema s return true, nil } - -func getPrimaryKeyColumns(db *gorm.DB, schema string, tableName string) ([]string, error) { - var primaryKeys []string - query := ` - SELECT a.attname - FROM pg_index i - JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) - JOIN pg_class c ON c.oid = i.indrelid - JOIN pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = ? AND c.relname = ? AND i.indisprimary; - ` - rows, err := db.Raw(query, schema, tableName).Rows() - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var pkColumn string - if err := rows.Scan(&pkColumn); err != nil { - return nil, err - } - primaryKeys = append(primaryKeys, pkColumn) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - if len(primaryKeys) == 0 { - return nil, fmt.Errorf("no primary key found for table %s", tableName) - } - - return primaryKeys, nil -} - -func getColumnTypes(db *gorm.DB, schema, tableName string) (map[string]string, error) { - columnTypes := make(map[string]string) - query := ` - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_schema = ? AND table_name = ?; - ` - rows, err := db.Raw(query, schema, tableName).Rows() - if err != nil { - return nil, err - } - defer rows.Close() - - for rows.Next() { - var columnName, dataType string - if err := rows.Scan(&columnName, &dataType); err != nil { - return nil, err - } - columnTypes[columnName] = dataType - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return columnTypes, nil -} - -func convertStringValue(value, columnType string) (interface{}, error) { - switch columnType { - case "integer", "smallint", "bigint": - return strconv.ParseInt(value, 10, 64) - case "numeric", "real", "double precision": - return strconv.ParseFloat(value, 64) - case "boolean": - return strconv.ParseBool(value) - case "uuid": - _, err := uuid.Parse(value) - if err != nil { - return nil, fmt.Errorf("invalid UUID format: %v", err) - } - return value, nil - case "date": - _, err := time.Parse("2006-01-02", value) - if err != nil { - return nil, fmt.Errorf("invalid date format: %v", err) - } - return value, nil - case "timestamp", "timestamp with time zone", "timestamp without time zone": - _, err := time.Parse(time.RFC3339, value) - if err != nil { - return nil, fmt.Errorf("invalid timestamp format: %v", err) - } - return value, nil - default: - return value, nil - } -} diff --git a/core/src/plugins/postgres/utils.go b/core/src/plugins/postgres/utils.go new file mode 100644 index 0000000..92d5181 --- /dev/null +++ b/core/src/plugins/postgres/utils.go @@ -0,0 +1,104 @@ +package postgres + +import ( + "fmt" + "strconv" + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +func getPrimaryKeyColumns(db *gorm.DB, schema string, tableName string) ([]string, error) { + var primaryKeys []string + query := ` + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + JOIN pg_class c ON c.oid = i.indrelid + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = ? AND c.relname = ? AND i.indisprimary; + ` + rows, err := db.Raw(query, schema, tableName).Rows() + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var pkColumn string + if err := rows.Scan(&pkColumn); err != nil { + return nil, err + } + primaryKeys = append(primaryKeys, pkColumn) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + if len(primaryKeys) == 0 { + return nil, fmt.Errorf("no primary key found for table %s", tableName) + } + + return primaryKeys, nil +} + +func getColumnTypes(db *gorm.DB, schema, tableName string) (map[string]string, error) { + columnTypes := make(map[string]string) + query := ` + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_schema = ? AND table_name = ?; + ` + rows, err := db.Raw(query, schema, tableName).Rows() + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var columnName, dataType string + if err := rows.Scan(&columnName, &dataType); err != nil { + return nil, err + } + columnTypes[columnName] = dataType + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return columnTypes, nil +} + +func convertStringValue(value, columnType string) (interface{}, error) { + switch columnType { + case "integer", "smallint", "bigint": + return strconv.ParseInt(value, 10, 64) + case "numeric", "real", "double precision": + return strconv.ParseFloat(value, 64) + case "boolean": + return strconv.ParseBool(value) + case "uuid": + _, err := uuid.Parse(value) + if err != nil { + return nil, fmt.Errorf("invalid UUID format: %v", err) + } + return value, nil + case "date": + _, err := time.Parse("2006-01-02", value) + if err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + return value, nil + case "timestamp", "timestamp with time zone", "timestamp without time zone": + _, err := time.Parse(time.RFC3339, value) + if err != nil { + return nil, fmt.Errorf("invalid timestamp format: %v", err) + } + return value, nil + default: + return value, nil + } +} diff --git a/core/src/plugins/redis/delete.go b/core/src/plugins/redis/delete.go new file mode 100644 index 0000000..88a1576 --- /dev/null +++ b/core/src/plugins/redis/delete.go @@ -0,0 +1,81 @@ +package redis + +import ( + "context" + "errors" + "fmt" + "github.com/clidey/whodb/core/src/engine" + "strconv" +) + +func (p *RedisPlugin) DeleteRow(config *engine.PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) { + client, err := DB(config) + if err != nil { + return false, err + } + defer client.Close() + + ctx := context.Background() + + keyType, err := client.Type(ctx, storageUnit).Result() + if err != nil { + return false, err + } + + switch keyType { + case "string": + // Deleting the entire string key + err := client.Del(ctx, storageUnit).Err() + if err != nil { + return false, err + } + case "hash": + // Deleting a specific field from a hash + field, ok := values["field"] + if !ok { + return false, errors.New("missing 'field' for hash deletion") + } + err := client.HDel(ctx, storageUnit, field).Err() + if err != nil { + return false, err + } + case "list": + // Removing an element from a list + indexStr, ok := values["index"] + if !ok { + return false, errors.New("missing 'index' for list deletion") + } + index, err := strconv.ParseInt(indexStr, 10, 64) + if err != nil { + return false, errors.New("unable to convert index to int") + } + value := client.LIndex(ctx, storageUnit, index).Val() + if err := client.LRem(ctx, storageUnit, 1, value).Err(); err != nil { + return false, errors.New("unable to remove the list item") + } + case "set": + // Removing a specific member from a set + member, ok := values["member"] + if !ok { + return false, errors.New("missing 'member' for set deletion") + } + err := client.SRem(ctx, storageUnit, member).Err() + if err != nil { + return false, err + } + case "zset": + // Removing a specific member from a sorted set + member, ok := values["member"] + if !ok { + return false, errors.New("missing 'member' for sorted set deletion") + } + err := client.ZRem(ctx, storageUnit, member).Err() + if err != nil { + return false, err + } + default: + return false, fmt.Errorf("unsupported Redis data type: %s", keyType) + } + + return true, nil +} diff --git a/core/src/plugins/redis/redis.go b/core/src/plugins/redis/redis.go index 1e22bc5..06a9750 100644 --- a/core/src/plugins/redis/redis.go +++ b/core/src/plugins/redis/redis.go @@ -214,6 +214,10 @@ func (p *RedisPlugin) RawExecute(config *engine.PluginConfig, query string) (*en return nil, errors.New("unsupported operation for Redis") } +func (p *RedisPlugin) Chat(config *engine.PluginConfig, schema string, model string, previousConversation string, query string) ([]*engine.ChatMessage, error) { + return nil, errors.ErrUnsupported +} + func NewRedisPlugin() *engine.Plugin { return &engine.Plugin{ Type: engine.DatabaseType_Redis, diff --git a/core/src/plugins/sqlite3/chat.go b/core/src/plugins/sqlite3/chat.go new file mode 100644 index 0000000..c82d209 --- /dev/null +++ b/core/src/plugins/sqlite3/chat.go @@ -0,0 +1,66 @@ +package sqlite3 + +import ( + "fmt" + "strings" + + "github.com/clidey/whodb/core/src/common" + "github.com/clidey/whodb/core/src/engine" + "github.com/clidey/whodb/core/src/llm" +) + +func (p *Sqlite3Plugin) Chat(config *engine.PluginConfig, schema string, model string, previousConversation string, query string) ([]*engine.ChatMessage, error) { + db, err := DB(config) + if err != nil { + return nil, err + } + sqlDb, err := db.DB() + if err != nil { + return nil, err + } + defer sqlDb.Close() + + tableFields, err := getTableSchema(db) + if err != nil { + return nil, err + } + + tableDetails := strings.Builder{} + for tableName, fields := range tableFields { + tableDetails.WriteString(fmt.Sprintf("table: %v\n", tableName)) + for _, field := range fields { + tableDetails.WriteString(fmt.Sprintf("- %v (%v)\n", field.Key, field.Value)) + } + } + + context := tableDetails.String() + + completeQuery := fmt.Sprintf(common.RawSQLQueryPrompt, "SQLite3", schema, context, previousConversation, query, "SQLite3") + + response, err := llm.Instance(llm.Ollama_LLMType).Complete(completeQuery, llm.LLMModel(model), nil) + if err != nil { + return nil, err + } + + chats := common.ExtractCodeFromResponse(*response) + chatMessages := []*engine.ChatMessage{} + for _, chat := range chats { + var result *engine.GetRowsResult + chatType := "message" + if chat.Type == "sql" { + rowResult, err := p.RawExecute(config, chat.Text) + if err != nil { + return nil, err + } + chatType = "sql" + result = rowResult + } + chatMessages = append(chatMessages, &engine.ChatMessage{ + Type: chatType, + Result: result, + Text: chat.Text, + }) + } + + return chatMessages, nil +} diff --git a/core/src/plugins/sqlite3/delete.go b/core/src/plugins/sqlite3/delete.go new file mode 100644 index 0000000..4daa667 --- /dev/null +++ b/core/src/plugins/sqlite3/delete.go @@ -0,0 +1,63 @@ +package sqlite3 + +import ( + "errors" + "fmt" + + "github.com/clidey/whodb/core/src/common" + "github.com/clidey/whodb/core/src/engine" +) + +func (p *Sqlite3Plugin) DeleteRow(config *engine.PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) { + db, err := DB(config) + if err != nil { + return false, err + } + + sqlDb, err := db.DB() + if err != nil { + return false, err + } + defer sqlDb.Close() + + pkColumns, columnTypes, err := getTableInfo(db, storageUnit) + if err != nil { + return false, err + } + + conditions := make(map[string]interface{}) + convertedValues := make(map[string]interface{}) + for column, strValue := range values { + columnType, exists := columnTypes[column] + if !exists { + return false, fmt.Errorf("column '%s' does not exist in table %s", column, storageUnit) + } + + convertedValue, err := convertStringValue(strValue, columnType) + if err != nil { + return false, fmt.Errorf("failed to convert value for column '%s': %v", column, err) + } + + if common.ContainsString(pkColumns, column) { + conditions[column] = convertedValue + } else { + convertedValues[column] = convertedValue + } + } + + dbConditions := db.Table(storageUnit) + for key, value := range conditions { + dbConditions = dbConditions.Where(fmt.Sprintf("%s = ?", key), value) + } + + result := dbConditions.Table(storageUnit).Delete(convertedValues) + if result.Error != nil { + return false, result.Error + } + + if result.RowsAffected == 0 { + return false, errors.New("no rows were deleted") + } + + return true, nil +} diff --git a/core/src/plugins/sqlite3/update.go b/core/src/plugins/sqlite3/update.go index b796a93..05b768b 100644 --- a/core/src/plugins/sqlite3/update.go +++ b/core/src/plugins/sqlite3/update.go @@ -3,12 +3,9 @@ package sqlite3 import ( "errors" "fmt" - "strconv" - "time" "github.com/clidey/whodb/core/src/common" "github.com/clidey/whodb/core/src/engine" - "gorm.io/gorm" ) func (p *Sqlite3Plugin) UpdateStorageUnit(config *engine.PluginConfig, schema string, storageUnit string, values map[string]string) (bool, error) { @@ -50,7 +47,7 @@ func (p *Sqlite3Plugin) UpdateStorageUnit(config *engine.PluginConfig, schema st dbConditions := db.Table(storageUnit) for key, value := range conditions { - dbConditions = dbConditions.Where(fmt.Sprintf("%s = ?", key), value) + dbConditions = dbConditions.Where(fmt.Sprintf("\"%s\" = ?", key), value) } result := dbConditions.Table(storageUnit).Updates(convertedValues) @@ -64,67 +61,3 @@ func (p *Sqlite3Plugin) UpdateStorageUnit(config *engine.PluginConfig, schema st return true, nil } - -func getTableInfo(db *gorm.DB, tableName string) ([]string, map[string]string, error) { - var primaryKeys []string - columnTypes := make(map[string]string) - pragmaQuery := fmt.Sprintf("PRAGMA table_info(%s)", tableName) - rows, err := db.Raw(pragmaQuery, tableName).Rows() - if err != nil { - return nil, nil, err - } - defer rows.Close() - - for rows.Next() { - var ( - cid int - name string - type_ string - notnull int - dfltValue interface{} - pk int - ) - if err := rows.Scan(&cid, &name, &type_, ¬null, &dfltValue, &pk); err != nil { - return nil, nil, err - } - columnTypes[name] = type_ - if pk == 1 { - primaryKeys = append(primaryKeys, name) - } - } - - if err := rows.Err(); err != nil { - return nil, nil, err - } - - if len(primaryKeys) == 0 { - return nil, nil, fmt.Errorf("no primary key found for table %s", tableName) - } - - return primaryKeys, columnTypes, nil -} - -func convertStringValue(value, columnType string) (interface{}, error) { - switch columnType { - case "INTEGER": - return strconv.ParseInt(value, 10, 64) - case "REAL": - return strconv.ParseFloat(value, 64) - case "BOOLEAN": - return strconv.ParseBool(value) - case "DATE": - _, err := time.Parse("2006-01-02", value) - if err != nil { - return nil, fmt.Errorf("invalid date format: %v", err) - } - return value, nil - case "DATETIME": - _, err := time.Parse(time.RFC3339, value) - if err != nil { - return nil, fmt.Errorf("invalid datetime format: %v", err) - } - return value, nil - default: - return value, nil - } -} diff --git a/core/src/plugins/sqlite3/utils.go b/core/src/plugins/sqlite3/utils.go new file mode 100644 index 0000000..808c4a5 --- /dev/null +++ b/core/src/plugins/sqlite3/utils.go @@ -0,0 +1,73 @@ +package sqlite3 + +import ( + "fmt" + "strconv" + "time" + + "gorm.io/gorm" +) + +func getTableInfo(db *gorm.DB, tableName string) ([]string, map[string]string, error) { + var primaryKeys []string + columnTypes := make(map[string]string) + pragmaQuery := fmt.Sprintf("PRAGMA table_info(%s)", tableName) + rows, err := db.Raw(pragmaQuery, tableName).Rows() + if err != nil { + return nil, nil, err + } + defer rows.Close() + + for rows.Next() { + var ( + cid int + name string + type_ string + notnull int + dfltValue interface{} + pk int + ) + if err := rows.Scan(&cid, &name, &type_, ¬null, &dfltValue, &pk); err != nil { + return nil, nil, err + } + columnTypes[name] = type_ + if pk == 1 { + primaryKeys = append(primaryKeys, name) + } + } + + if err := rows.Err(); err != nil { + return nil, nil, err + } + + if len(primaryKeys) == 0 { + return nil, nil, fmt.Errorf("no primary key found for table %s", tableName) + } + + return primaryKeys, columnTypes, nil +} + +func convertStringValue(value, columnType string) (interface{}, error) { + switch columnType { + case "INTEGER": + return strconv.ParseInt(value, 10, 64) + case "REAL": + return strconv.ParseFloat(value, 64) + case "BOOLEAN": + return strconv.ParseBool(value) + case "DATE": + _, err := time.Parse("2006-01-02", value) + if err != nil { + return nil, fmt.Errorf("invalid date format: %v", err) + } + return value, nil + case "DATETIME": + _, err := time.Parse(time.RFC3339, value) + if err != nil { + return nil, fmt.Errorf("invalid datetime format: %v", err) + } + return value, nil + default: + return value, nil + } +} diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 3a1d43f..7522384 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,9 +1,10 @@ +import classNames from "classnames"; import { map } from "lodash"; -import { Navigate, Route, Routes } from "react-router-dom"; +import { Route, Routes } from "react-router-dom"; import { Notifications } from './components/notifications'; -import { InternalRoutes, PrivateRoute, PublicRoutes, getRoutes } from './config/routes'; +import { PrivateRoute, PublicRoutes, getRoutes } from './config/routes'; +import { NavigateToDefault } from "./pages/chat/default-chat-route"; import { useAppSelector } from "./store/hooks"; -import classNames from "classnames"; export const App = () => { const darkModeEnabled = useAppSelector(state => state.global.theme === "dark"); @@ -17,7 +18,7 @@ export const App = () => { {map(getRoutes(), route => ( ))} - } /> + } /> diff --git a/frontend/src/app.tsx b/frontend/src/app.tsx index 3a1d43f..7522384 100644 --- a/frontend/src/app.tsx +++ b/frontend/src/app.tsx @@ -1,9 +1,10 @@ +import classNames from "classnames"; import { map } from "lodash"; -import { Navigate, Route, Routes } from "react-router-dom"; +import { Route, Routes } from "react-router-dom"; import { Notifications } from './components/notifications'; -import { InternalRoutes, PrivateRoute, PublicRoutes, getRoutes } from './config/routes'; +import { PrivateRoute, PublicRoutes, getRoutes } from './config/routes'; +import { NavigateToDefault } from "./pages/chat/default-chat-route"; import { useAppSelector } from "./store/hooks"; -import classNames from "classnames"; export const App = () => { const darkModeEnabled = useAppSelector(state => state.global.theme === "dark"); @@ -17,7 +18,7 @@ export const App = () => { {map(getRoutes(), route => ( ))} - } /> + } /> diff --git a/frontend/src/components/button.tsx b/frontend/src/components/button.tsx index 5c36442..9e124e4 100644 --- a/frontend/src/components/button.tsx +++ b/frontend/src/components/button.tsx @@ -23,7 +23,7 @@ export const Button: FC = (props) => { {props.label} {cloneElement(props.icon, { - className: classNames("w-4 h-4 stroke-gray-600 dark:stroke-white", props.iconClassName), + className: twMerge(classNames("w-4 h-4 stroke-gray-600 dark:stroke-white", props.iconClassName)), })} } diff --git a/frontend/src/components/dropdown.tsx b/frontend/src/components/dropdown.tsx index 1e138b2..62ddf6e 100644 --- a/frontend/src/components/dropdown.tsx +++ b/frontend/src/components/dropdown.tsx @@ -74,8 +74,8 @@ export const Dropdown: FC = (props) => { })}>
    { - props.items.map((item) => ( -
  • ( +
  • handleClick(item)}>
    {props.value?.id === item.id ? Icons.CheckCircle : item.icon}
    diff --git a/frontend/src/components/editor.tsx b/frontend/src/components/editor.tsx index 1f14cfd..4dfb7b1 100644 --- a/frontend/src/components/editor.tsx +++ b/frontend/src/components/editor.tsx @@ -14,7 +14,7 @@ languages.register({ id: 'sql' }); type ICodeEditorProps = { value: string; - setValue: (value: string) => void; + setValue?: (value: string) => void; language?: "sql" | "markdown" | "json"; options?: EditorProps["options"]; onRun?: () => void; @@ -29,7 +29,9 @@ export const CodeEditor: FC = ({ value, setValue, language, op const handleEditorDidMount: OnMount = useCallback(editor => { editorRef.current = editor; - }, []); + editor.setSelection({ startLineNumber: 1, startColumn: value.length+1, endLineNumber: 1, endColumn: value.length+1 }); + editor.focus(); + }, [value.length]); const handlePreviewToggle = useCallback(async () => { setShowPreview(p => !p); @@ -55,7 +57,7 @@ export const CodeEditor: FC = ({ value, setValue, language, op const handleChange = useCallback((newValue: string | undefined) => { if (newValue != null) { - setValue(newValue); + setValue?.(newValue); } }, [setValue]); diff --git a/frontend/src/components/hooks.tsx b/frontend/src/components/hooks.tsx index 1bdae75..ba67302 100644 --- a/frontend/src/components/hooks.tsx +++ b/frontend/src/components/hooks.tsx @@ -1,11 +1,17 @@ import { useCallback, useEffect, useRef, useState } from "react"; -export const useExportToCSV = (columns: string[], rows: Record[]) => { +export const useExportToCSV = (columns: string[], rows: Record[], specificIndexes: number[] = []) => { return useCallback(() => { + let selectedRows: Record[]; + if (specificIndexes.length === 0) { + selectedRows = rows; + } else { + selectedRows = specificIndexes.map(index => rows[index]); + } const csvContent = [ columns.join(','), - ...rows.map(row => columns.map(col => row[col]).join(",")) + ...selectedRows.map(row => columns.map(col => row[col]).join(",")) ].join('\n'); const blob = new Blob([csvContent], { type: 'text/csv;charset=utf-8;' }); @@ -20,7 +26,7 @@ export const useExportToCSV = (columns: string[], rows: Record[] link.click(); document.body.removeChild(link); } - }, [columns, rows]); + }, [columns, rows, specificIndexes]); }; type ILongPressProps = { diff --git a/frontend/src/components/icons.tsx b/frontend/src/components/icons.tsx index 6c5e90f..1dea78d 100644 --- a/frontend/src/components/icons.tsx +++ b/frontend/src/components/icons.tsx @@ -21,7 +21,7 @@ export const Icons = { Cancel: , - Tables: + Tables: , DoubleRightArrow: @@ -42,7 +42,7 @@ export const Icons = { Database: , - DocumentMagnify: + DocumentMagnify: , Console: @@ -72,23 +72,57 @@ export const Icons = { DownCaret: , - PlusCircle: + PlusCircle: , - Adjustments: + Adjustments: , - Text: + Text: , - Code: + Code: , + Chat: + + , + Users: + + , + Feedback: + + , + Sales: + + , + Products: + + , + Inventory: + + , + Financials: + + , + Marketing: + + , + Projects: + + , + Analytics: + + + , + Compliance: + + , Logos: { Postgres: , MySQL: , MariaDB: , - Sqlite3: , + Sqlite3: , MongoDB: , Redis: , ElasticSearch: , diff --git a/frontend/src/components/input.tsx b/frontend/src/components/input.tsx index 8b3c33a..85e8a9a 100644 --- a/frontend/src/components/input.tsx +++ b/frontend/src/components/input.tsx @@ -1,5 +1,5 @@ import classNames from "classnames"; -import { ChangeEvent, ChangeEventHandler, DetailedHTMLProps, FC, InputHTMLAttributes, KeyboardEvent, cloneElement, useCallback, useState } from "react"; +import { ChangeEvent, ChangeEventHandler, DetailedHTMLProps, FC, InputHTMLAttributes, KeyboardEventHandler, cloneElement, useCallback, useState } from "react"; import { twMerge } from "tailwind-merge"; import { Icons } from "./icons"; @@ -18,20 +18,24 @@ type InputProps = { value: string; setValue?: (value: string) => void; type?: "text" | "password"; + onSubmit?: () => void; } -export const Input: FC = ({ value, setValue, type, placeholder, inputProps = {} }) => { +export const Input: FC = ({ value, setValue, type, placeholder, onSubmit, inputProps = {} }) => { const handleChange: ChangeEventHandler = useCallback((e) => { setValue?.(e.target.value); inputProps.onChange?.(e); }, [inputProps, setValue]); - const handleKeyDown = useCallback((e: KeyboardEvent) => { - inputProps.onKeyDown?.(e); - }, [inputProps]); - + const handleHandleKeyUp: KeyboardEventHandler = useCallback((e) => { + if (e.key === "Enter") { + onSubmit?.(); + } + inputProps?.onKeyUp?.(e); + }, [inputProps, onSubmit]); + return } @@ -77,3 +81,19 @@ export const ToggleInput: FC = ({ value, setValue }) => { ); } + + +type ICheckBoxInputProps = { + value: boolean; + setValue?: (value: boolean) => void; +} + +export const CheckBoxInput: FC = ({ value, setValue }) => { + const handleChange = useCallback((e: ChangeEvent) => { + setValue?.(e.target.checked); + }, [setValue]); + + return ( + + ); +} diff --git a/frontend/src/components/loading.tsx b/frontend/src/components/loading.tsx index 364d554..9cab4cd 100644 --- a/frontend/src/components/loading.tsx +++ b/frontend/src/components/loading.tsx @@ -5,13 +5,14 @@ import { twMerge } from "tailwind-merge"; type ILoadingProps = { className?: string; hideText?: boolean; + containerClassName?: string; loadingText?: string; textClassName?: string; } -export const Loading: FC = ({ className, hideText, loadingText, textClassName }) => { +export const Loading: FC = ({ containerClassName, className, hideText, loadingText, textClassName }) => { return ( -
    +
    - + {cloneElement(Icons.Search, { className: "w-4 h-4 absolute right-2 top-1/2 -translate-y-1/2 stroke-gray-500 dark:stroke-neutral-500 cursor-pointer transition-all hover:scale-110 rounded-full group-hover/search-input:opacity-10", })} diff --git a/frontend/src/components/sidebar/sidebar.tsx b/frontend/src/components/sidebar/sidebar.tsx index b2d8170..4045b26 100644 --- a/frontend/src/components/sidebar/sidebar.tsx +++ b/frontend/src/components/sidebar/sidebar.tsx @@ -7,7 +7,7 @@ import { useDispatch } from "react-redux"; import { Link, useLocation, useNavigate } from "react-router-dom"; import { twMerge } from "tailwind-merge"; import { InternalRoutes, PublicRoutes } from "../../config/routes"; -import { DatabaseType, useGetDatabaseQuery, useGetSchemaQuery, useLoginMutation, useLoginWithProfileMutation } from "../../generated/graphql"; +import { DatabaseType, useGetAiModelsQuery, useGetDatabaseQuery, useGetSchemaQuery, useLoginMutation, useLoginWithProfileMutation } from "../../generated/graphql"; import { AuthActions, LocalLoginProfile } from "../../store/auth"; import { DatabaseActions } from "../../store/database"; import { notify } from "../../store/function"; @@ -151,6 +151,7 @@ export const Sidebar: FC = () => { const pathname = useLocation().pathname; const current = useAppSelector(state => state.auth.current); const profiles = useAppSelector(state => state.auth.profiles); + const { data: aiModels } = useGetAiModelsQuery(); const { data: availableDatabases, loading: availableDatabasesLoading } = useGetDatabaseQuery({ variables: { type: current?.Type as DatabaseType, @@ -266,6 +267,14 @@ export const Sidebar: FC = () => { path: InternalRoutes.Graph.path, }, ]; + + if (!isNoSQL(current.Type) && aiModels?.AIModel != null && aiModels.AIModel.length > 0) { + routes.unshift({ + title: "Chat", + icon: Icons.Chat, + path: InternalRoutes.Chat.path, + }); + } if (!DATABASES_THAT_DONT_SUPPORT_SCRATCH_PAD.includes(current.Type as DatabaseType)) { routes.push({ title: "Scratchpad", @@ -274,7 +283,7 @@ export const Sidebar: FC = () => { }); } return routes; - }, [current]); + }, [aiModels?.AIModel, current]); const handleCollapseToggle = useCallback(() => { setCollapsed(c => !c); diff --git a/frontend/src/components/table.tsx b/frontend/src/components/table.tsx index 045a2bb..60f8ae4 100644 --- a/frontend/src/components/table.tsx +++ b/frontend/src/components/table.tsx @@ -1,18 +1,20 @@ import classNames from "classnames"; import { AnimatePresence, motion } from "framer-motion"; +import { clone, isString, values } from "lodash"; import { CSSProperties, FC, KeyboardEvent, MouseEvent, useCallback, useEffect, useMemo, useRef, useState } from "react"; import { Cell, Row, useBlockLayout, useTable } from 'react-table'; import { FixedSizeList, ListChildComponentProps } from "react-window"; import { twMerge } from "tailwind-merge"; +import { notify } from "../store/function"; import { isMarkdown, isNumeric, isValidJSON } from "../utils/functions"; import { ActionButton, AnimatedButton } from "./button"; import { Portal } from "./common"; import { CodeEditor } from "./editor"; import { useExportToCSV, useLongPress } from "./hooks"; import { Icons } from "./icons"; -import { SearchInput } from "./search"; +import { CheckBoxInput } from "./input"; import { Loading } from "./loading"; -import { clone, values } from "lodash"; +import { SearchInput } from "./search"; type IPaginationProps = { pageCount: number; @@ -85,12 +87,14 @@ const Pagination: FC = ({ pageCount, currentPage, onPageChange }; type ITDataProps = { - cell: Cell>; - onCellUpdate?: (row: Cell>) => Promise; + cell: Cell>; + onCellUpdate?: (row: Cell>) => Promise; disableEdit?: boolean; + checked?: boolean; + onRowCheck?: (value: boolean) => void; } -const TData: FC = ({ cell, onCellUpdate, disableEdit }) => { +const TData: FC = ({ cell, onCellUpdate, checked, onRowCheck, disableEdit }) => { const [changed, setChanged] = useState(false); const [editedData, setEditedData] = useState(cell.value); const [editable, setEditable] = useState(false); @@ -99,6 +103,7 @@ const TData: FC = ({ cell, onCellUpdate, disableEdit }) => { const cellRef = useRef(null); const [copied, setCopied] = useState(false); const [updating, setUpdating] = useState(false); + const [escapeAttempted, setEscapeAttempted] = useState(false); const handleChange = useCallback((value: string) => { setEditedData(value); @@ -161,7 +166,25 @@ const TData: FC = ({ cell, onCellUpdate, disableEdit }) => { }); }, [cell, editedData, onCellUpdate]); + const handleEditorEscapeButton = useCallback((e: KeyboardEvent) => { + if (e.key === "Escape" && !changed) { + handleCancel(); + } else if (e.key === "Escape" && changed) { + if (escapeAttempted) { + setEscapeAttempted(false); + handleCancel(); + } else { + setEscapeAttempted(true); + notify("You have unsaved changes, please save or cancel. Pressing Escape again will close without saving.", "warning"); + setTimeout(() => setEscapeAttempted(false), 2000); // reset it in case + } + } + }, [changed, handleCancel, escapeAttempted]); + const language = useMemo(() => { + if (editedData == null) { + return; + } if (isValidJSON(editedData)) { return "json"; } @@ -188,8 +211,21 @@ const TData: FC = ({ cell, onCellUpdate, disableEdit }) => { className={classNames("w-full h-full p-2 leading-tight focus:outline-none focus:shadow-outline appearance-none transition-all duration-300 border-solid border-gray-200 dark:border-white/5 overflow-hidden whitespace-nowrap select-none text-gray-600 dark:text-neutral-300", { "group-even/row:bg-gray-100 hover:bg-gray-300 group-even/row:hover:bg-gray-300 dark:group-even/row:bg-white/10 dark:group-odd/row:bg-white/5 dark:group-even/row:hover:bg-white/15 dark:group-odd/row:hover:bg-white/15": !editable, "bg-transparent": editable, - })} - {...longPressProps}>{editedData}
    + })}> +
    + +
    +
    + {editedData} +
    +