diff --git a/internal/threagile/analyze.go b/internal/threagile/analyze.go index 593bc69f..56f8aa6f 100644 --- a/internal/threagile/analyze.go +++ b/internal/threagile/analyze.go @@ -15,16 +15,16 @@ func (what *Threagile) initAnalyze() *Threagile { Short: "Analyze model", Aliases: []string{"analyze", "analyse", "run", "analyse-model"}, RunE: func(cmd *cobra.Command, args []string) error { - cfg := what.readConfig(cmd, what.buildTimestamp) + what.processArgs(cmd, args) commands := what.readCommands() - progressReporter := DefaultProgressReporter{Verbose: cfg.GetVerbose()} + progressReporter := DefaultProgressReporter{Verbose: what.config.GetVerbose()} - r, err := model.ReadAndAnalyzeModel(cfg, risks.GetBuiltInRiskRules(), progressReporter) + r, err := model.ReadAndAnalyzeModel(what.config, risks.GetBuiltInRiskRules(), progressReporter) if err != nil { return fmt.Errorf("failed to read and analyze model: %w", err) } - err = report.Generate(cfg, r, commands, risks.GetBuiltInRiskRules(), progressReporter) + err = report.Generate(what.config, r, commands, risks.GetBuiltInRiskRules(), progressReporter) if err != nil { return fmt.Errorf("failed to generate reports: %w", err) } diff --git a/internal/threagile/config.go b/internal/threagile/config.go index 21c31297..9326fa2f 100644 --- a/internal/threagile/config.go +++ b/internal/threagile/config.go @@ -53,8 +53,8 @@ type Config struct { RiskExcelValue RiskExcelConfig `json:"RiskExcel" yaml:"RiskExcel"` ServerModeValue bool `json:"ServerMode,omitempty" yaml:"ServerMode"` - DiagramDPIValue int `json:"DiagramDPI,omitempty" yaml:"DiagramDPI"` ServerPortValue int `json:"ServerPort,omitempty" yaml:"ServerPort"` + DiagramDPIValue int `json:"DiagramDPI,omitempty" yaml:"DiagramDPI"` GraphvizDPIValue int `json:"GraphvizDPI,omitempty" yaml:"GraphvizDPI"` MaxGraphvizDPIValue int `json:"MaxGraphvizDPI,omitempty" yaml:"MaxGraphvizDPI"` BackupHistoryFilesToKeepValue int `json:"BackupHistoryFilesToKeep,omitempty" yaml:"BackupHistoryFilesToKeep"` @@ -63,6 +63,16 @@ type Config struct { KeepDiagramSourceFilesValue bool `json:"KeepDiagramSourceFiles,omitempty" yaml:"KeepDiagramSourceFiles"` IgnoreOrphanedRiskTrackingValue bool `json:"IgnoreOrphanedRiskTracking,omitempty" yaml:"IgnoreOrphanedRiskTracking"` + SkipDataFlowDiagramValue bool `json:"SkipDataFlowDiagram,omitempty" yaml:"SkipDataFlowDiagram"` + SkipDataAssetDiagramValue bool `json:"SkipDataAssetDiagram,omitempty" yaml:"SkipDataAssetDiagram"` + SkipRisksJSONValue bool `json:"SkipRisksJSON,omitempty" yaml:"SkipRisksJSON"` + SkipTechnicalAssetsJSONValue bool `json:"SkipTechnicalAssetsJSON,omitempty" yaml:"SkipTechnicalAssetsJSON"` + SkipStatsJSONValue bool `json:"SkipStatsJSON,omitempty" yaml:"SkipStatsJSON"` + SkipRisksExcelValue bool `json:"SkipRisksExcel,omitempty" yaml:"SkipRisksExcel"` + SkipTagsExcelValue bool `json:"SkipTagsExcel,omitempty" yaml:"SkipTagsExcel"` + SkipReportPDFValue bool `json:"SkipReportPDF,omitempty" yaml:"SkipReportPDF"` + SkipReportADOCValue bool `json:"SkipReportADOC,omitempty" yaml:"SkipReportADOC"` + AttractivenessValue Attractiveness `json:"Attractiveness" yaml:"Attractiveness"` ReportConfigurationValue report.ReportConfiguation `json:"ReportConfiguration" yaml:"ReportConfiguration"` @@ -103,8 +113,8 @@ type ConfigGetter interface { GetRiskExcelShrinkColumnsToFit() bool GetRiskExcelColorText() bool GetServerMode() bool - GetDiagramDPI() int GetServerPort() int + GetDiagramDPI() int GetGraphvizDPI() int GetMinGraphvizDPI() int GetMaxGraphvizDPI() int @@ -112,13 +122,21 @@ type ConfigGetter interface { GetAddModelTitle() bool GetKeepDiagramSourceFiles() bool GetIgnoreOrphanedRiskTracking() bool + GetSkipDataFlowDiagram() bool + GetSkipDataAssetDiagram() bool + GetSkipRisksJSON() bool + GetSkipTechnicalAssetsJSON() bool + GetSkipStatsJSON() bool + GetSkipRisksExcel() bool + GetSkipTagsExcel() bool + GetSkipReportPDF() bool + GetSkipReportADOC() bool GetAttractiveness() Attractiveness GetReportConfiguration() report.ReportConfiguation GetThreagileVersion() string GetProgressReporter() types.ProgressReporter GetReportConfigurationHideChapters() map[report.ChaptersToShowHide]bool } - type ConfigSetter interface { SetVerbose(verbose bool) SetInteractive(interactive bool) @@ -132,8 +150,8 @@ type ConfigSetter interface { SetRiskRulePlugins(riskRulePlugins []string) SetSkipRiskRules(skipRiskRules []string) SetServerMode(serverMode bool) - SetDiagramDPI(diagramDPI int) SetServerPort(serverPort int) + SetDiagramDPI(diagramDPI int) SetIgnoreOrphanedRiskTracking(ignoreOrphanedRiskTracking bool) } @@ -436,7 +454,7 @@ func (c *Config) Merge(config Config, values map[string]any) { } case strings.ToLower("ServerMode"): - // not configurable via config file + c.ServerModeValue = config.ServerModeValue case strings.ToLower("DiagramDPI"): c.DiagramDPIValue = config.DiagramDPIValue @@ -718,14 +736,6 @@ func (c *Config) SetServerMode(serverMode bool) { c.ServerModeValue = serverMode } -func (c *Config) GetDiagramDPI() int { - return c.DiagramDPIValue -} - -func (c *Config) SetDiagramDPI(diagramDPI int) { - c.DiagramDPIValue = diagramDPI -} - func (c *Config) GetServerPort() int { return c.ServerPortValue } @@ -734,6 +744,14 @@ func (c *Config) SetServerPort(serverPort int) { c.ServerPortValue = serverPort } +func (c *Config) GetDiagramDPI() int { + return c.DiagramDPIValue +} + +func (c *Config) SetDiagramDPI(diagramDPI int) { + c.DiagramDPIValue = diagramDPI +} + func (c *Config) GetGraphvizDPI() int { return c.GraphvizDPIValue } @@ -766,6 +784,42 @@ func (c *Config) SetIgnoreOrphanedRiskTracking(ignoreOrphanedRiskTracking bool) c.IgnoreOrphanedRiskTrackingValue = ignoreOrphanedRiskTracking } +func (c *Config) GetSkipDataFlowDiagram() bool { + return c.SkipDataFlowDiagramValue +} + +func (c *Config) GetSkipDataAssetDiagram() bool { + return c.SkipDataAssetDiagramValue +} + +func (c *Config) GetSkipRisksJSON() bool { + return c.SkipRisksJSONValue +} + +func (c *Config) GetSkipTechnicalAssetsJSON() bool { + return c.SkipTechnicalAssetsJSONValue +} + +func (c *Config) GetSkipStatsJSON() bool { + return c.SkipStatsJSONValue +} + +func (c *Config) GetSkipRisksExcel() bool { + return c.SkipRisksExcelValue +} + +func (c *Config) GetSkipTagsExcel() bool { + return c.SkipTagsExcelValue +} + +func (c *Config) GetSkipReportPDF() bool { + return c.SkipReportPDFValue +} + +func (c *Config) GetSkipReportADOC() bool { + return c.SkipReportADOCValue +} + func (c *Config) GetAttractiveness() Attractiveness { return c.AttractivenessValue } diff --git a/internal/threagile/create.go b/internal/threagile/create.go index 6f8c5330..750e8282 100644 --- a/internal/threagile/create.go +++ b/internal/threagile/create.go @@ -17,6 +17,8 @@ func (what *Threagile) initCreate() *Threagile { Short: "Create example threagile model", Long: "\n" + Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp) + "\n\njust create an example model named threagile-example-model.yaml in the output directory", RunE: func(cmd *cobra.Command, args []string) error { + what.processArgs(cmd, args) + appDir, err := cmd.Flags().GetString(appDirFlagName) if err != nil { cmd.Printf("Unable to read app-dir flag: %v", err) @@ -48,23 +50,25 @@ func (what *Threagile) initCreate() *Threagile { Short: "Create stub threagile model", Long: "\n" + Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp) + "\n\njust create a minimal stub model named threagile-stub-model.yaml in the output directory", RunE: func(cmd *cobra.Command, args []string) error { - cfg := what.readConfig(cmd, what.buildTimestamp) + what.processArgs(cmd, args) - err := examples.CreateStubModelFile(cfg.GetAppFolder(), cfg.GetOutputFolder(), InputFile) + err := examples.CreateStubModelFile(what.config.GetAppFolder(), what.config.GetOutputFolder(), InputFile) if err != nil { cmd.Printf("Unable to copy stub model: %v", err) return err } - if !what.flags.interactiveFlag { + if !what.config.GetInteractive() { cmd.Println(Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp)) } - cmd.Printf("A minimal stub model was created named threagile-stub-model.yaml in %q.\n", cfg.GetOutputFolder()) - if !what.flags.interactiveFlag { + + cmd.Printf("A minimal stub model was created named threagile-stub-model.yaml in %q.\n", what.config.GetOutputFolder()) + if !what.config.GetInteractive() { cmd.Println() cmd.Println(Examples) cmd.Println() } + return nil }, }) @@ -74,6 +78,8 @@ func (what *Threagile) initCreate() *Threagile { Short: "Create editing support", Long: "\n" + Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp) + "\n\njust create some editing support stuff in the output directory", RunE: func(cmd *cobra.Command, args []string) error { + what.processArgs(cmd, args) + appDir, err := cmd.Flags().GetString(appDirFlagName) if err != nil { cmd.Printf("Unable to read app-dir flag: %v", err) diff --git a/internal/threagile/execute.go b/internal/threagile/execute.go index 1867f203..39510f94 100644 --- a/internal/threagile/execute.go +++ b/internal/threagile/execute.go @@ -20,19 +20,21 @@ func (what *Threagile) initExecute() *Threagile { Short: "Execute model macro", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - cfg := what.readConfig(cmd, what.buildTimestamp) - progressReporter := DefaultProgressReporter{Verbose: cfg.GetVerbose()} + what.processArgs(cmd, args) - r, err := model.ReadAndAnalyzeModel(cfg, risks.GetBuiltInRiskRules(), progressReporter) + progressReporter := DefaultProgressReporter{Verbose: what.config.GetVerbose()} + + r, err := model.ReadAndAnalyzeModel(what.config, risks.GetBuiltInRiskRules(), progressReporter) if err != nil { return fmt.Errorf("unable to read and analyze model: %w", err) } macrosId := args[0] - err = macros.ExecuteModelMacro(r.ModelInput, cfg.GetInputFile(), r.ParsedModel, macrosId) + err = macros.ExecuteModelMacro(r.ModelInput, what.config.GetInputFile(), r.ParsedModel, macrosId) if err != nil { return fmt.Errorf("unable to execute model macro: %w", err) } + return nil }, }) diff --git a/internal/threagile/explain.go b/internal/threagile/explain.go index 2b5b8068..95747029 100644 --- a/internal/threagile/explain.go +++ b/internal/threagile/explain.go @@ -2,8 +2,6 @@ package threagile import ( "fmt" - "strings" - "github.com/spf13/cobra" "github.com/threagile/threagile/pkg/macros" "github.com/threagile/threagile/pkg/model" @@ -50,13 +48,14 @@ func (what *Threagile) initExplainNew() *Threagile { return what } -func (what *Threagile) explainRisk(cmd *cobra.Command, _ []string) error { - cfg := what.readConfig(cmd, what.buildTimestamp) - progressReporter := DefaultProgressReporter{Verbose: cfg.GetVerbose()} +func (what *Threagile) explainRisk(cmd *cobra.Command, args []string) error { + what.processArgs(cmd, args) + + progressReporter := DefaultProgressReporter{Verbose: what.config.GetVerbose()} // todo: reuse model if already loaded - result, runError := model.ReadAndAnalyzeModel(cfg, risks.GetBuiltInRiskRules(), progressReporter) + result, runError := model.ReadAndAnalyzeModel(what.config, risks.GetBuiltInRiskRules(), progressReporter) if runError != nil { cmd.Printf("Failed to read and analyze model: %v", runError) return runError @@ -68,20 +67,16 @@ func (what *Threagile) explainRisk(cmd *cobra.Command, _ []string) error { return fmt.Errorf("not implemented yet") } -func (what *Threagile) explainRules(cmd *cobra.Command, _ []string) error { +func (what *Threagile) explainRules(cmd *cobra.Command, args []string) error { + what.processArgs(cmd, args) + cmd.Println(Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp)) cmd.Println("Explanation for risk rules:") cmd.Println() cmd.Println("----------------------") cmd.Println("Custom risk rules:") cmd.Println("----------------------") - cfg := new(Config).Defaults(what.buildTimestamp) - configError := cfg.Load(what.flags.configFlag) - if configError != nil { - cmd.Printf("WARNING: failed to load config file %q: %v\n", what.flags.configFlag, configError) - } - - customRiskRules := model.LoadCustomRiskRules(cfg.GetPluginFolder(), strings.Split(what.flags.customRiskRulesPluginFlag, ","), DefaultProgressReporter{Verbose: what.flags.verboseFlag}) + customRiskRules := model.LoadCustomRiskRules(what.config.GetPluginFolder(), what.config.GetRiskRulePlugins(), DefaultProgressReporter{Verbose: what.config.GetVerbose()}) for _, rule := range customRiskRules { cmd.Printf("%v: %v\n", rule.Category().ID, rule.Category().Description) } @@ -99,6 +94,8 @@ func (what *Threagile) explainRules(cmd *cobra.Command, _ []string) error { } func (what *Threagile) explainMacros(cmd *cobra.Command, args []string) { + what.processArgs(cmd, args) + cmd.Println(Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp)) cmd.Println("Explanation for the model macros:") cmd.Println() @@ -122,12 +119,14 @@ func (what *Threagile) explainMacros(cmd *cobra.Command, args []string) { } func (what *Threagile) explainTypes(cmd *cobra.Command, args []string) { + what.processArgs(cmd, args) + cmd.Println(Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp)) fmt.Println("Explanation for the types:") cmd.Println() cmd.Println("The following types are available (can be extended for custom rules):") cmd.Println() - for name, values := range types.GetBuiltinTypeValues(what.readConfig(cmd, what.buildTimestamp)) { + for name, values := range types.GetBuiltinTypeValues(what.config) { cmd.Println(name) for _, candidate := range values { cmd.Printf("\t %v: %v\n", candidate, candidate.Explain()) diff --git a/internal/threagile/flags.go b/internal/threagile/flags.go index f95de19c..8080e318 100644 --- a/internal/threagile/flags.go +++ b/internal/threagile/flags.go @@ -7,28 +7,58 @@ package threagile const ( configFlagName = "config" + verboseFlagName = "Verbose" + verboseFlagShorthand = "v" + interactiveFlagName = "interactive" interactiveFlagShorthand = "i" - verboseFlagName = "Verbose" - verboseFlagShorthand = "v" + appDirFlagName = "app-dir" + pluginDirFlagName = "plugin-dir" + dataDirFlagName = "data-dir" + outputFlagName = "output" + serverDirFlagName = "server-dir" + tempDirFlagName = "temp-dir" + keyDirFlagName = "key-dir" - appDirFlagName = "app-dir" - outputFlagName = "output" - tempDirFlagName = "temp-dir" + inputFileFlagName = "model" + dataFlowDiagramPNGFileFlagName = "data-flow-diagram-png" + dataAssetDiagramPNGFileFlagName = "data-asset-diagram-png" + dataFlowDiagramDOTFileFlagName = "data-flow-diagram-dot" + dataAssetDiagramDOTFileFlagName = "data-asset-diagram-dot" + reportFileFlagName = "report" + risksExcelFileFlagName = "risks-excel" + tagsExcelFileFlagName = "tags-excel" + risksJsonFileFlagName = "risks-json" + technicalAssetsJsonFileFlagName = "technical-assets-json" + statsJsonFileFlagName = "stats-json" + templateFileNameFlagName = "background" + reportLogoImagePathFlagName = "reportLogoImagePath" + technologyFileFlagName = "technology" - serverDirFlagName = "server-dir" - serverPortFlagName = "server-port" + customRiskRulesPluginFlagName = "custom-risk-rules-plugin" + skipRiskRulesFlagName = "skip-risk-rules" + executeModelMacroFlagName = "execute-model-macro" - inputFileFlagName = "model" + serverModeFlagName = "server-mode" + serverPortFlagName = "server-port" + diagramDpiFlagName = "diagram-dpi" + graphvizDpiFlagName = "graphviz-dpi" + backupHistoryFilesToKeepFlagName = "backup-history-files-to-keep" - customRiskRulesPluginFlagName = "custom-risk-rules-plugin" - skipRiskRulesFlagName = "skip-risk-rules" + addModelTitleFlagName = "add-model-title" + keepDiagramSourceFilesFlagName = "keep-diagram-source-files" ignoreOrphanedRiskTrackingFlagName = "ignore-orphaned-risk-tracking" - diagramDpiFlagName = "diagram-dpi" - templateFileNameFlagName = "background" - reportLogoImagePathFlagName = "reportLogoImagePath" + skipDataFlowDiagramFlagName = "skip-data-flow-diagram" + skipDataAssetDiagramFlagName = "skip-data-asset-diagram" + skipRisksJSONFlagName = "skip-risks-json" + skipTechnicalAssetsJSONFlagName = "skip-technical-assets-json" + skipStatsJSONFlagName = "skip-stats-json" + skipRisksExcelFlagName = "skip-risks-excel" + skipTagsExcelFlagName = "skip-tags-excel" + skipReportPDFFlagName = "skip-report-pdf" + skipReportADOCFlagName = "skip-report-adoc" generateDataFlowDiagramFlagName = "generate-data-flow-diagram" generateDataAssetDiagramFlagName = "generate-data-asset-diagram" @@ -42,30 +72,19 @@ const ( ) type Flags struct { - configFlag string - verboseFlag bool - interactiveFlag bool - appDirFlag string - outputDirFlag string - tempDirFlag string - inputFileFlag string - serverPortFlag int - serverDirFlag string + Config - skipRiskRulesFlag string - customRiskRulesPluginFlag string - ignoreOrphanedRiskTrackingFlag bool - templateFileNameFlag string - reportLogoImagePathFlag string - diagramDpiFlag int + configFlag string + riskRulePluginsValue string + skipRiskRulesValue string - generateDataFlowDiagramFlag bool - generateDataAssetDiagramFlag bool - generateRisksJSONFlag bool - generateTechnicalAssetsJSONFlag bool - generateStatsJSONFlag bool - generateRisksExcelFlag bool - generateTagsExcelFlag bool - generateReportPDFFlag bool - generateReportADOCFlag bool + generateDataFlowDiagramFlag bool // deprecated + generateDataAssetDiagramFlag bool // deprecated + generateRisksJSONFlag bool // deprecated + generateTechnicalAssetsJSONFlag bool // deprecated + generateStatsJSONFlag bool // deprecated + generateRisksExcelFlag bool // deprecated + generateTagsExcelFlag bool // deprecated + generateReportPDFFlag bool // deprecated + generateReportADOCFlag bool // deprecated } diff --git a/internal/threagile/list.go b/internal/threagile/list.go index a5d54d75..94321b02 100644 --- a/internal/threagile/list.go +++ b/internal/threagile/list.go @@ -2,8 +2,6 @@ package threagile import ( "fmt" - "strings" - "github.com/spf13/cobra" "github.com/threagile/threagile/pkg/macros" "github.com/threagile/threagile/pkg/model" @@ -16,19 +14,15 @@ func (what *Threagile) initList() *Threagile { Use: ListRiskRulesCommand, Short: "Print available risk rules", RunE: func(cmd *cobra.Command, args []string) error { + what.processArgs(cmd, args) + cmd.Println(Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp)) cmd.Println("The following risk rules are available (can be extended via custom risk rules):") cmd.Println() cmd.Println("----------------------") cmd.Println("Custom risk rules:") cmd.Println("----------------------") - cfg := new(Config).Defaults(what.buildTimestamp) - configError := cfg.Load(what.flags.configFlag) - if configError != nil { - cmd.Printf("WARNING: failed to load config file %q: %v\n", what.flags.configFlag, configError) - } - - customRiskRules := model.LoadCustomRiskRules(cfg.GetPluginFolder(), strings.Split(what.flags.customRiskRulesPluginFlag, ","), DefaultProgressReporter{Verbose: what.flags.verboseFlag}) + customRiskRules := model.LoadCustomRiskRules(what.config.GetPluginFolder(), what.config.GetRiskRulePlugins(), DefaultProgressReporter{Verbose: what.config.GetVerbose()}) for id, customRule := range customRiskRules { cmd.Println(id, "-->", customRule.Category().Title, "--> with tags:", customRule.SupportedTags()) } @@ -49,6 +43,8 @@ func (what *Threagile) initList() *Threagile { Use: ListModelMacrosCommand, Short: "Print model macros", Run: func(cmd *cobra.Command, args []string) { + what.processArgs(cmd, args) + cmd.Println(Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp)) cmd.Println("The following model macros are available (can be extended via custom model macros):") cmd.Println() @@ -75,12 +71,14 @@ func (what *Threagile) initList() *Threagile { Use: ListTypesCommand, Short: "Print type information (enum values to be used in models)", Run: func(cmd *cobra.Command, args []string) { + what.processArgs(cmd, args) + cmd.Println(Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp)) cmd.Println() cmd.Println() cmd.Println("The following types are available (can be extended for custom rules):") cmd.Println() - for name, values := range types.GetBuiltinTypeValues(what.readConfig(cmd, what.buildTimestamp)) { + for name, values := range types.GetBuiltinTypeValues(what.config) { cmd.Println(fmt.Sprintf(" %v: %v", name, values)) } }, diff --git a/internal/threagile/print.go b/internal/threagile/print.go index f24885bb..fdb488b4 100644 --- a/internal/threagile/print.go +++ b/internal/threagile/print.go @@ -19,24 +19,30 @@ func (what *Threagile) initPrint() *Threagile { Use: PrintLicenseCommand, Short: "Print license information", RunE: func(cmd *cobra.Command, args []string) error { + what.processArgs(cmd, args) + appDir, err := cmd.Flags().GetString(appDirFlagName) if err != nil { cmd.Printf("Unable to read app-dir flag: %v", err) return err } + cmd.Println(Logo + "\n\n" + fmt.Sprintf(VersionText, what.buildTimestamp)) if appDir != filepath.Clean(appDir) { // TODO: do we need this check here? cmd.Printf("weird app folder %v", appDir) return fmt.Errorf("weird app folder") } + content, err := os.ReadFile(filepath.Clean(filepath.Join(appDir, "LICENSE.txt"))) if err != nil { cmd.Printf("Unable to read license file: %v", err) return err } + cmd.Print(string(content)) cmd.Println() + return nil }, }) diff --git a/internal/threagile/quit.go b/internal/threagile/quit.go index 559c296f..f3e17ee0 100644 --- a/internal/threagile/quit.go +++ b/internal/threagile/quit.go @@ -12,6 +12,7 @@ func (what *Threagile) initQuit() *Threagile { Short: "quit client", Aliases: []string{"exit", "bye", "x", "q"}, Run: func(cmd *cobra.Command, args []string) { + what.processArgs(cmd, args) os.Exit(0) }, CompletionOptions: cobra.CompletionOptions{ diff --git a/internal/threagile/root.go b/internal/threagile/root.go index a7c87538..4953a340 100644 --- a/internal/threagile/root.go +++ b/internal/threagile/root.go @@ -15,8 +15,6 @@ import ( "github.com/mattn/go-shellwords" "github.com/spf13/cobra" - "github.com/spf13/pflag" - "github.com/threagile/threagile/pkg/report" ) @@ -53,48 +51,88 @@ func (what *Threagile) initRoot() *Threagile { }, } - defaultConfig := new(Config).Defaults(what.buildTimestamp) - - what.rootCmd.PersistentFlags().StringVar(&what.flags.appDirFlag, appDirFlagName, defaultConfig.GetAppFolder(), "app folder") - what.rootCmd.PersistentFlags().StringVar(&what.flags.outputDirFlag, outputFlagName, defaultConfig.GetOutputFolder(), "output directory") - what.rootCmd.PersistentFlags().StringVar(&what.flags.tempDirFlag, tempDirFlagName, defaultConfig.GetTempFolder(), "temporary folder location") - - what.rootCmd.PersistentFlags().StringVar(&what.flags.inputFileFlag, inputFileFlagName, defaultConfig.GetInputFile(), "input model yaml file") + what.config = new(Config).Defaults(what.buildTimestamp) + return what.initFlags() +} - what.rootCmd.PersistentFlags().BoolVarP(&what.flags.interactiveFlag, interactiveFlagName, interactiveFlagShorthand, defaultConfig.GetInteractive(), "interactive mode") - what.rootCmd.PersistentFlags().BoolVarP(&what.flags.verboseFlag, verboseFlagName, verboseFlagShorthand, defaultConfig.GetVerbose(), "Verbose output") +func (what *Threagile) initFlags() *Threagile { + what.rootCmd.ResetFlags() what.rootCmd.PersistentFlags().StringVar(&what.flags.configFlag, configFlagName, "", "config file") - what.rootCmd.PersistentFlags().StringVar(&what.flags.customRiskRulesPluginFlag, customRiskRulesPluginFlagName, strings.Join(defaultConfig.GetRiskRulePlugins(), ","), "comma-separated list of plugins file names with custom risk rules to load") - what.rootCmd.PersistentFlags().IntVar(&what.flags.diagramDpiFlag, diagramDpiFlagName, defaultConfig.GetDiagramDPI(), "DPI used to render: maximum is "+fmt.Sprintf("%d", defaultConfig.GetMaxGraphvizDPI())+"") - what.rootCmd.PersistentFlags().StringVar(&what.flags.skipRiskRulesFlag, skipRiskRulesFlagName, strings.Join(defaultConfig.GetSkipRiskRules(), ","), "comma-separated list of risk rules (by their ID) to skip") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.ignoreOrphanedRiskTrackingFlag, ignoreOrphanedRiskTrackingFlagName, defaultConfig.GetIgnoreOrphanedRiskTracking(), "ignore orphaned risk tracking (just log them) not matching a concrete risk") - what.rootCmd.PersistentFlags().StringVar(&what.flags.templateFileNameFlag, templateFileNameFlagName, defaultConfig.GetTemplateFilename(), "background pdf file") - what.rootCmd.PersistentFlags().StringVar(&what.flags.reportLogoImagePathFlag, reportLogoImagePathFlagName, defaultConfig.GetReportLogoImagePath(), "reportLogoImagePath") - - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateDataFlowDiagramFlag, generateDataFlowDiagramFlagName, true, "generate data flow diagram") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateDataAssetDiagramFlag, generateDataAssetDiagramFlagName, true, "generate data asset diagram") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateRisksJSONFlag, generateRisksJSONFlagName, true, "generate risks json") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateTechnicalAssetsJSONFlag, generateTechnicalAssetsJSONFlagName, true, "generate technical assets json") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateStatsJSONFlag, generateStatsJSONFlagName, true, "generate stats json") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateRisksExcelFlag, generateRisksExcelFlagName, true, "generate risks excel") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateTagsExcelFlag, generateTagsExcelFlagName, true, "generate tags excel") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateReportPDFFlag, generateReportPDFFlagName, true, "generate report pdf, including diagrams") - what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateReportADOCFlag, generateReportADOCFlagName, true, "generate report adoc, including diagrams") - - _ = what.rootCmd.PersistentFlags().Parse(os.Args[1:]) + what.rootCmd.PersistentFlags().BoolVarP(&what.flags.VerboseValue, verboseFlagName, verboseFlagShorthand, what.config.GetVerbose(), "Verbose output") + what.rootCmd.PersistentFlags().BoolVarP(&what.flags.InteractiveValue, interactiveFlagName, interactiveFlagShorthand, what.config.GetInteractive(), "interactive mode") + + what.rootCmd.PersistentFlags().StringVar(&what.flags.AppFolderValue, appDirFlagName, what.config.GetAppFolder(), "app folder") + what.rootCmd.PersistentFlags().StringVar(&what.flags.PluginFolderValue, pluginDirFlagName, what.config.GetPluginFolder(), "plugin directory") + what.rootCmd.PersistentFlags().StringVar(&what.flags.DataFolderValue, dataDirFlagName, what.config.GetDataFolder(), "data directory") + what.rootCmd.PersistentFlags().StringVar(&what.flags.OutputFolderValue, outputFlagName, what.config.GetOutputFolder(), "output directory") + what.rootCmd.PersistentFlags().StringVar(&what.flags.TempFolderValue, tempDirFlagName, what.config.GetTempFolder(), "temporary folder location") + what.rootCmd.PersistentFlags().StringVar(&what.flags.KeyFolderValue, keyDirFlagName, what.config.GetKeyFolder(), "key folder location") + + what.rootCmd.PersistentFlags().StringVar(&what.flags.InputFileValue, inputFileFlagName, what.config.GetInputFile(), "input model yaml file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.DataFlowDiagramFilenamePNGValue, dataFlowDiagramPNGFileFlagName, what.config.GetDataFlowDiagramFilenamePNG(), "data flow diagram PNG file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.DataAssetDiagramFilenamePNGValue, dataAssetDiagramPNGFileFlagName, what.config.GetDataAssetDiagramFilenamePNG(), "data asset diagram PNG file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.DataFlowDiagramFilenameDOTValue, dataFlowDiagramDOTFileFlagName, what.config.GetDataFlowDiagramFilenameDOT(), "data flow diagram DOT file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.DataAssetDiagramFilenameDOTValue, dataAssetDiagramDOTFileFlagName, what.config.GetDataAssetDiagramFilenameDOT(), "data asset diagram DOT file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.ReportFilenameValue, reportFileFlagName, what.config.GetReportFilename(), "report file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.ExcelRisksFilenameValue, risksExcelFileFlagName, what.config.GetExcelRisksFilename(), "risks Excel file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.ExcelTagsFilenameValue, tagsExcelFileFlagName, what.config.GetExcelTagsFilename(), "tags Excel file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.JsonRisksFilenameValue, risksJsonFileFlagName, what.config.GetJsonRisksFilename(), "risks JSON file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.JsonTechnicalAssetsFilenameValue, technicalAssetsJsonFileFlagName, what.config.GetJsonTechnicalAssetsFilename(), "technical assets JSON file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.JsonStatsFilenameValue, statsJsonFileFlagName, what.config.GetJsonStatsFilename(), "stats JSON file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.TemplateFilenameValue, templateFileNameFlagName, what.config.GetTemplateFilename(), "template pdf file") + what.rootCmd.PersistentFlags().StringVar(&what.flags.ReportLogoImagePathValue, reportLogoImagePathFlagName, what.config.GetReportLogoImagePath(), "report logo image") + what.rootCmd.PersistentFlags().StringVar(&what.flags.TechnologyFilenameValue, technologyFileFlagName, what.config.GetTechnologyFilename(), "file name of additional technologies") + + what.rootCmd.PersistentFlags().StringVar(&what.flags.riskRulePluginsValue, customRiskRulesPluginFlagName, strings.Join(what.config.GetRiskRulePlugins(), ","), "comma-separated list of plugins file names with custom risk rules to load") + what.rootCmd.PersistentFlags().StringVar(&what.flags.skipRiskRulesValue, skipRiskRulesFlagName, strings.Join(what.config.GetSkipRiskRules(), ","), "comma-separated list of risk rules (by their ID) to skip") + what.rootCmd.PersistentFlags().StringVar(&what.flags.ExecuteModelMacroValue, executeModelMacroFlagName, what.config.GetExecuteModelMacro(), "macro to execute") + + // RiskExcelValue not available as flags + + what.rootCmd.PersistentFlags().IntVar(&what.flags.ServerPortValue, serverPortFlagName, what.config.GetServerPort(), "server port") + what.rootCmd.PersistentFlags().StringVar(&what.flags.ServerFolderValue, serverDirFlagName, what.config.GetDataFolder(), "base folder for server mode (default: "+DataDir+")") + what.rootCmd.PersistentFlags().IntVar(&what.flags.DiagramDPIValue, diagramDpiFlagName, what.config.GetDiagramDPI(), "DPI used to render: maximum is "+fmt.Sprintf("%d", what.config.GetMaxGraphvizDPI())+"") + // MaxGraphvizDPIValue not available as flags + what.rootCmd.PersistentFlags().IntVar(&what.flags.BackupHistoryFilesToKeepValue, backupHistoryFilesToKeepFlagName, what.config.GetBackupHistoryFilesToKeep(), "number of backup history files to keep") + + what.rootCmd.PersistentFlags().BoolVar(&what.flags.AddModelTitleValue, addModelTitleFlagName, what.config.GetAddModelTitle(), "add model title") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.KeepDiagramSourceFilesValue, keepDiagramSourceFilesFlagName, what.config.GetKeepDiagramSourceFiles(), "keep diagram source files") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.IgnoreOrphanedRiskTrackingValue, ignoreOrphanedRiskTrackingFlagName, what.config.GetIgnoreOrphanedRiskTracking(), "ignore orphaned risk tracking (just log them) not matching a concrete risk") + + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipDataFlowDiagramValue, skipDataFlowDiagramFlagName, what.config.GetSkipDataFlowDiagram(), "skip generating data flow diagram") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipDataAssetDiagramValue, skipDataAssetDiagramFlagName, what.config.GetSkipDataAssetDiagram(), "skip generating data asset diagram") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipRisksJSONValue, skipRisksJSONFlagName, what.config.GetSkipRisksJSON(), "skip generating risks json") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipTechnicalAssetsJSONValue, skipTechnicalAssetsJSONFlagName, what.config.GetSkipTechnicalAssetsJSON(), "skip generating technical assets json") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipStatsJSONValue, skipStatsJSONFlagName, what.config.GetSkipStatsJSON(), "skip generating stats json") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipRisksExcelValue, skipRisksExcelFlagName, what.config.GetSkipRisksExcel(), "skip generating risks excel") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipTagsExcelValue, skipTagsExcelFlagName, what.config.GetSkipTagsExcel(), "skip generating tags excel") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipReportPDFValue, skipReportPDFFlagName, what.config.GetSkipReportPDF(), "skip generating report pdf, including diagrams") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.SkipReportADOCValue, skipReportADOCFlagName, what.config.GetSkipReportADOC(), "skip generating report adoc, including diagrams") + + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateDataFlowDiagramFlag, generateDataFlowDiagramFlagName, !what.config.GetSkipDataFlowDiagram(), "(deprecated) generate generating data flow diagram") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateDataAssetDiagramFlag, generateDataAssetDiagramFlagName, !what.config.GetSkipDataAssetDiagram(), "(deprecated) generate generating data asset diagram") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateRisksJSONFlag, generateRisksJSONFlagName, !what.config.GetSkipRisksJSON(), "(deprecated) generate generating risks json") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateTechnicalAssetsJSONFlag, generateTechnicalAssetsJSONFlagName, !what.config.GetSkipTechnicalAssetsJSON(), "(deprecated) generate generating technical assets json") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateStatsJSONFlag, generateStatsJSONFlagName, !what.config.GetSkipStatsJSON(), "(deprecated) generate generating stats json") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateRisksExcelFlag, generateRisksExcelFlagName, !what.config.GetSkipRisksExcel(), "(deprecated) generate generating risks excel") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateTagsExcelFlag, generateTagsExcelFlagName, !what.config.GetSkipTagsExcel(), "(deprecated) generate generating tags excel") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateReportPDFFlag, generateReportPDFFlagName, !what.config.GetSkipReportPDF(), "(deprecated) generate generating report pdf, including diagrams") + what.rootCmd.PersistentFlags().BoolVar(&what.flags.generateReportADOCFlag, generateReportADOCFlagName, !what.config.GetSkipReportADOC(), "(deprecated) generate generating report adoc, including diagrams") + + // AttractivenessValue not available as flags + // ReportConfigurationValue not available as flags return what } -func (what *Threagile) run(cmd *cobra.Command, _ []string) { - if !what.flags.interactiveFlag { - cfg := what.readConfig(cmd, what.buildTimestamp) - if !cfg.GetInteractive() { - cmd.Println("Please add the --interactive flag to run in interactive mode.") - return - } +func (what *Threagile) run(thisCmd *cobra.Command, args []string) { + what.processArgs(thisCmd, args) + + if !what.config.GetInteractive() { + what.rootCmd.Println("Please add the --interactive flag to run in interactive mode.") + return } completer := readline.NewPrefixCompleter() @@ -104,7 +142,7 @@ func (what *Threagile) run(cmd *cobra.Command, _ []string) { dir, homeError := os.UserHomeDir() if homeError != nil { - cmd.Println("Error, please report bug at https://github.com/Threagile/threagile. Unable to find home directory: " + homeError.Error()) + what.rootCmd.Println("Error, please report bug at https://github.com/Threagile/threagile. Unable to find home directory: " + homeError.Error()) return } @@ -119,7 +157,7 @@ func (what *Threagile) run(cmd *cobra.Command, _ []string) { }) if readlineError != nil { - cmd.Println("Error, please report bug at https://github.com/Threagile/threagile. Unable to initialize readline: " + readlineError.Error()) + what.rootCmd.Println("Error, please report bug at https://github.com/Threagile/threagile. Unable to initialize readline: " + readlineError.Error()) return } @@ -131,7 +169,7 @@ func (what *Threagile) run(cmd *cobra.Command, _ []string) { return } if readError != nil { - cmd.Println("Error, please report bug at https://github.com/Threagile/threagile. Unable to read line: " + readError.Error()) + what.rootCmd.Println("Error, please report bug at https://github.com/Threagile/threagile. Unable to read line: " + readError.Error()) return } @@ -141,24 +179,24 @@ func (what *Threagile) run(cmd *cobra.Command, _ []string) { params, parseError := shellwords.Parse(line) if parseError != nil { - cmd.Printf("failed to parse command line: %s", parseError.Error()) + what.rootCmd.Printf("failed to parse command line: %s", parseError.Error()) continue } cmd, args, findError := what.rootCmd.Find(params) if findError != nil { - cmd.Printf("failed to find command: %s", findError.Error()) + what.rootCmd.Printf("failed to find command: %s", findError.Error()) continue } if cmd == nil || cmd == what.rootCmd { - cmd.Println("failed to find command") + what.rootCmd.Println("failed to find command") continue } flagsError := cmd.ParseFlags(args) if flagsError != nil { - cmd.Printf("invalid flags: %s", flagsError.Error()) + what.rootCmd.Printf("invalid flags: %s", flagsError.Error()) continue } @@ -180,8 +218,9 @@ func (what *Threagile) run(cmd *cobra.Command, _ []string) { if cmd.RunE != nil { runError := cmd.RunE(cmd, args) if runError != nil { - cmd.Printf("error: %v \n", runError) + what.rootCmd.Printf("error: %v \n", runError) } + continue } @@ -213,76 +252,263 @@ func (what *Threagile) usage(cmd *cobra.Command) string { func (what *Threagile) readCommands() *report.GenerateCommands { commands := new(report.GenerateCommands).Defaults() - commands.DataFlowDiagram = what.flags.generateDataFlowDiagramFlag - commands.DataAssetDiagram = what.flags.generateDataAssetDiagramFlag - commands.RisksJSON = what.flags.generateRisksJSONFlag - commands.StatsJSON = what.flags.generateStatsJSONFlag - commands.TechnicalAssetsJSON = what.flags.generateTechnicalAssetsJSONFlag - commands.RisksExcel = what.flags.generateRisksExcelFlag - commands.TagsExcel = what.flags.generateTagsExcelFlag - commands.ReportPDF = what.flags.generateReportPDFFlag - commands.ReportADOC = what.flags.generateReportADOCFlag + commands.DataFlowDiagram = !what.flags.SkipDataFlowDiagramValue + commands.DataAssetDiagram = !what.flags.SkipDataAssetDiagramValue + commands.RisksJSON = !what.flags.SkipRisksJSONValue + commands.StatsJSON = !what.flags.SkipStatsJSONValue + commands.TechnicalAssetsJSON = !what.flags.SkipTechnicalAssetsJSONValue + commands.RisksExcel = !what.flags.SkipRisksExcelValue + commands.TagsExcel = !what.flags.SkipTagsExcelValue + commands.ReportPDF = !what.flags.SkipReportPDFValue + commands.ReportADOC = !what.flags.SkipReportADOCValue return commands } -func (what *Threagile) readConfig(cmd *cobra.Command, buildTimestamp string) *Config { - cfg := new(Config).Defaults(buildTimestamp) - configError := cfg.Load(what.flags.configFlag) - if configError != nil { - cmd.Printf("WARNING: failed to load config file %q: %v\n", what.flags.configFlag, configError) +func (what *Threagile) processSystemArgs(cmd *cobra.Command) *Threagile { + what.config.InteractiveValue = what.processArgs(cmd, os.Args[1:]) + return what +} + +func (what *Threagile) processArgs(cmd *cobra.Command, args []string) bool { + _ = cmd.PersistentFlags().Parse(args) + + if what.isFlagOverridden(cmd, configFlagName) { + configError := what.config.Load(what.flags.configFlag) + if configError != nil { + what.rootCmd.Printf("WARNING: failed to load config file %q: %v\n", what.flags.configFlag, configError) + } + } + + if what.isFlagOverridden(cmd, verboseFlagName) { + what.config.VerboseValue = what.flags.VerboseValue + } + + interactive := what.config.GetInteractive() + if what.isFlagOverridden(cmd, interactiveFlagName) { + interactive = what.flags.InteractiveValue } - flags := cmd.Flags() - if isFlagOverridden(flags, serverPortFlagName) { - cfg.ServerPortValue = what.flags.serverPortFlag + if what.isFlagOverridden(cmd, appDirFlagName) { + what.config.AppFolderValue = what.config.CleanPath(what.flags.AppFolderValue) } - if isFlagOverridden(flags, serverDirFlagName) { - cfg.ServerFolderValue = cfg.CleanPath(what.flags.serverDirFlag) + + if what.isFlagOverridden(cmd, pluginDirFlagName) { + what.config.PluginFolderValue = what.config.CleanPath(what.flags.PluginFolderValue) + } + + if what.isFlagOverridden(cmd, dataDirFlagName) { + what.config.DataFolderValue = what.config.CleanPath(what.flags.DataFolderValue) } - if isFlagOverridden(flags, appDirFlagName) { - cfg.AppFolderValue = cfg.CleanPath(what.flags.appDirFlag) + if what.isFlagOverridden(cmd, outputFlagName) { + what.config.OutputFolderValue = what.config.CleanPath(what.flags.OutputFolderValue) } - if isFlagOverridden(flags, outputFlagName) { - cfg.OutputFolderValue = cfg.CleanPath(what.flags.outputDirFlag) + + if what.isFlagOverridden(cmd, serverDirFlagName) { + what.config.ServerFolderValue = what.config.CleanPath(what.flags.ServerFolderValue) } - if isFlagOverridden(flags, tempDirFlagName) { - cfg.TempFolderValue = cfg.CleanPath(what.flags.tempDirFlag) + + if what.isFlagOverridden(cmd, tempDirFlagName) { + what.config.TempFolderValue = what.config.CleanPath(what.flags.TempFolderValue) } - if isFlagOverridden(flags, verboseFlagName) { - cfg.VerboseValue = what.flags.verboseFlag + if what.isFlagOverridden(cmd, keyDirFlagName) { + what.config.KeyFolderValue = what.config.CleanPath(what.flags.KeyFolderValue) } - if isFlagOverridden(flags, inputFileFlagName) { - cfg.InputFileValue = cfg.CleanPath(what.flags.inputFileFlag) + if what.isFlagOverridden(cmd, inputFileFlagName) { + what.config.InputFileValue = what.config.CleanPath(what.flags.InputFileValue) } - if isFlagOverridden(flags, customRiskRulesPluginFlagName) { - cfg.SetRiskRulePlugins(strings.Split(what.flags.customRiskRulesPluginFlag, ",")) + if what.isFlagOverridden(cmd, dataFlowDiagramPNGFileFlagName) { + what.config.DataFlowDiagramFilenamePNGValue = what.config.CleanPath(what.flags.DataFlowDiagramFilenamePNGValue) } - if isFlagOverridden(flags, skipRiskRulesFlagName) { - cfg.SkipRiskRulesValue = strings.Split(what.flags.skipRiskRulesFlag, ",") + + if what.isFlagOverridden(cmd, dataAssetDiagramPNGFileFlagName) { + what.config.DataAssetDiagramFilenamePNGValue = what.config.CleanPath(what.flags.DataAssetDiagramFilenamePNGValue) } - if isFlagOverridden(flags, ignoreOrphanedRiskTrackingFlagName) { - cfg.IgnoreOrphanedRiskTrackingValue = what.flags.ignoreOrphanedRiskTrackingFlag + + if what.isFlagOverridden(cmd, dataFlowDiagramDOTFileFlagName) { + what.config.DataFlowDiagramFilenameDOTValue = what.config.CleanPath(what.flags.DataFlowDiagramFilenameDOTValue) } - if isFlagOverridden(flags, diagramDpiFlagName) { - cfg.DiagramDPIValue = what.flags.diagramDpiFlag + + if what.isFlagOverridden(cmd, dataAssetDiagramDOTFileFlagName) { + what.config.DataAssetDiagramFilenameDOTValue = what.config.CleanPath(what.flags.DataAssetDiagramFilenameDOTValue) } - if isFlagOverridden(flags, templateFileNameFlagName) { - cfg.TemplateFilenameValue = what.flags.templateFileNameFlag + + if what.isFlagOverridden(cmd, reportFileFlagName) { + what.config.ReportFilenameValue = what.config.CleanPath(what.flags.ReportFilenameValue) } - if isFlagOverridden(flags, reportLogoImagePathFlagName) { - cfg.ReportLogoImagePathValue = what.flags.reportLogoImagePathFlag + + if what.isFlagOverridden(cmd, risksExcelFileFlagName) { + what.config.ExcelRisksFilenameValue = what.config.CleanPath(what.flags.ExcelRisksFilenameValue) + } + + if what.isFlagOverridden(cmd, tagsExcelFileFlagName) { + what.config.ExcelTagsFilenameValue = what.config.CleanPath(what.flags.ExcelTagsFilenameValue) + } + + if what.isFlagOverridden(cmd, risksJsonFileFlagName) { + what.config.JsonRisksFilenameValue = what.config.CleanPath(what.flags.JsonRisksFilenameValue) + } + + if what.isFlagOverridden(cmd, technicalAssetsJsonFileFlagName) { + what.config.JsonTechnicalAssetsFilenameValue = what.config.CleanPath(what.flags.JsonTechnicalAssetsFilenameValue) + } + + if what.isFlagOverridden(cmd, statsJsonFileFlagName) { + what.config.JsonStatsFilenameValue = what.config.CleanPath(what.flags.JsonStatsFilenameValue) + } + + if what.isFlagOverridden(cmd, templateFileNameFlagName) { + what.config.TemplateFilenameValue = what.flags.TemplateFilenameValue + } + + if what.isFlagOverridden(cmd, reportLogoImagePathFlagName) { + what.config.ReportLogoImagePathValue = what.flags.ReportLogoImagePathValue + } + + if what.isFlagOverridden(cmd, technologyFileFlagName) { + what.config.TechnologyFilenameValue = what.flags.TechnologyFilenameValue + } + + if what.isFlagOverridden(cmd, customRiskRulesPluginFlagName) { + what.config.RiskRulePluginsValue = strings.Split(what.flags.riskRulePluginsValue, ",") + } + + if what.isFlagOverridden(cmd, skipRiskRulesFlagName) { + what.config.SkipRiskRulesValue = strings.Split(what.flags.skipRiskRulesValue, ",") + } + + if what.isFlagOverridden(cmd, executeModelMacroFlagName) { + what.config.ExecuteModelMacroValue = what.flags.ExecuteModelMacroValue + } + + // RiskExcelValue not available as flags + + if what.isFlagOverridden(cmd, serverModeFlagName) { + what.config.ServerModeValue = what.flags.ServerModeValue + } + + if what.isFlagOverridden(cmd, serverPortFlagName) { + what.config.ServerPortValue = what.flags.ServerPortValue + } + + if what.isFlagOverridden(cmd, diagramDpiFlagName) { + what.config.DiagramDPIValue = what.flags.DiagramDPIValue } - return cfg + + if what.isFlagOverridden(cmd, graphvizDpiFlagName) { + what.config.GraphvizDPIValue = what.flags.GraphvizDPIValue + } + + // MaxGraphvizDPIValue not available as flags + + if what.isFlagOverridden(cmd, backupHistoryFilesToKeepFlagName) { + what.config.BackupHistoryFilesToKeepValue = what.flags.BackupHistoryFilesToKeepValue + } + + if what.isFlagOverridden(cmd, addModelTitleFlagName) { + what.config.AddModelTitleValue = what.flags.AddModelTitleValue + } + + if what.isFlagOverridden(cmd, keepDiagramSourceFilesFlagName) { + what.config.KeepDiagramSourceFilesValue = what.flags.KeepDiagramSourceFilesValue + } + + if what.isFlagOverridden(cmd, ignoreOrphanedRiskTrackingFlagName) { + what.config.IgnoreOrphanedRiskTrackingValue = what.flags.IgnoreOrphanedRiskTrackingValue + } + + if what.isFlagOverridden(cmd, skipDataFlowDiagramFlagName) { + what.config.SkipDataFlowDiagramValue = what.flags.SkipDataFlowDiagramValue + } + + if what.isFlagOverridden(cmd, skipDataAssetDiagramFlagName) { + what.config.SkipDataAssetDiagramValue = what.flags.SkipDataAssetDiagramValue + } + + if what.isFlagOverridden(cmd, skipRisksJSONFlagName) { + what.config.SkipRisksJSONValue = what.flags.SkipRisksJSONValue + } + + if what.isFlagOverridden(cmd, skipTechnicalAssetsJSONFlagName) { + what.config.SkipTechnicalAssetsJSONValue = what.flags.SkipTechnicalAssetsJSONValue + } + + if what.isFlagOverridden(cmd, skipStatsJSONFlagName) { + what.config.SkipStatsJSONValue = what.flags.SkipStatsJSONValue + } + + if what.isFlagOverridden(cmd, skipRisksExcelFlagName) { + what.config.SkipRisksExcelValue = what.flags.SkipRisksExcelValue + } + + if what.isFlagOverridden(cmd, skipTagsExcelFlagName) { + what.config.SkipTagsExcelValue = what.flags.SkipTagsExcelValue + } + + if what.isFlagOverridden(cmd, skipReportPDFFlagName) { + what.config.SkipReportPDFValue = what.flags.SkipReportPDFValue + } + + if what.isFlagOverridden(cmd, skipReportADOCFlagName) { + what.config.SkipReportADOCValue = what.flags.SkipReportADOCValue + } + + if what.isFlagOverridden(cmd, generateDataFlowDiagramFlagName) { + what.config.SkipDataFlowDiagramValue = !what.flags.generateDataFlowDiagramFlag + } + + if what.isFlagOverridden(cmd, generateDataAssetDiagramFlagName) { + what.config.SkipDataAssetDiagramValue = !what.flags.generateDataAssetDiagramFlag + } + + if what.isFlagOverridden(cmd, generateRisksJSONFlagName) { + what.config.SkipRisksJSONValue = !what.flags.generateRisksJSONFlag + } + + if what.isFlagOverridden(cmd, generateTechnicalAssetsJSONFlagName) { + what.config.SkipTechnicalAssetsJSONValue = !what.flags.generateTechnicalAssetsJSONFlag + } + + if what.isFlagOverridden(cmd, generateStatsJSONFlagName) { + what.config.SkipStatsJSONValue = !what.flags.generateStatsJSONFlag + } + + if what.isFlagOverridden(cmd, generateRisksExcelFlagName) { + what.config.SkipRisksExcelValue = !what.flags.generateRisksExcelFlag + } + + if what.isFlagOverridden(cmd, generateTagsExcelFlagName) { + what.config.SkipTagsExcelValue = !what.flags.generateTagsExcelFlag + } + + if what.isFlagOverridden(cmd, generateReportPDFFlagName) { + what.config.SkipReportPDFValue = !what.flags.generateReportPDFFlag + } + + if what.isFlagOverridden(cmd, generateReportADOCFlagName) { + what.config.SkipReportADOCValue = !what.flags.generateReportADOCFlag + } + + // AttractivenessValue not available as flags + // ReportConfigurationValue not available as flags + + what.initFlags() + + return interactive } -func isFlagOverridden(flags *pflag.FlagSet, flagName string) bool { - flag := flags.Lookup(flagName) +func (what *Threagile) isFlagOverridden(cmd *cobra.Command, flagName string) bool { + if cmd == nil { + return false + } + + flag := cmd.PersistentFlags().Lookup(flagName) if flag == nil { return false } + return flag.Changed } diff --git a/internal/threagile/server.go b/internal/threagile/server.go index feaca668..3f936348 100644 --- a/internal/threagile/server.go +++ b/internal/threagile/server.go @@ -7,27 +7,30 @@ import ( ) func (what *Threagile) initServer() *Threagile { - defaultConfig := new(Config).Defaults(what.buildTimestamp) - serverCmd := &cobra.Command{ Use: "server", Short: "Run server", RunE: func(cmd *cobra.Command, args []string) error { - cfg := what.readConfig(cmd, what.buildTimestamp) - cfg.SetServerMode(true) - serverError := cfg.CheckServerFolder() - if serverError != nil { - return serverError - } - server.RunServer(cfg, risks.GetBuiltInRiskRules()) - return nil + what.processArgs(cmd, args) + return what.runServer() }, } - serverCmd.PersistentFlags().IntVar(&what.flags.serverPortFlag, serverPortFlagName, defaultConfig.GetServerPort(), "server port") - serverCmd.PersistentFlags().StringVar(&what.flags.serverDirFlag, serverDirFlagName, defaultConfig.GetDataFolder(), "base folder for server mode (default: "+DataDir+")") + serverCmd.PersistentFlags().IntVar(&what.flags.ServerPortValue, serverPortFlagName, what.config.GetServerPort(), "server port") + serverCmd.PersistentFlags().StringVar(&what.flags.ServerFolderValue, serverDirFlagName, what.config.GetDataFolder(), "base folder for server mode (default: "+DataDir+")") what.rootCmd.AddCommand(serverCmd) return what } + +func (what *Threagile) runServer() error { + what.config.SetServerMode(true) + serverError := what.config.CheckServerFolder() + if serverError != nil { + return serverError + } + + server.RunServer(what.config, risks.GetBuiltInRiskRules()) + return nil +} diff --git a/internal/threagile/threagile.go b/internal/threagile/threagile.go index 72336ef0..dc2aef4f 100644 --- a/internal/threagile/threagile.go +++ b/internal/threagile/threagile.go @@ -8,6 +8,7 @@ import ( type Threagile struct { flags Flags + config *Config rootCmd *cobra.Command buildTimestamp string } @@ -19,13 +20,15 @@ func (what *Threagile) Execute() { os.Exit(1) } - cfg := what.readConfig(what.rootCmd, what.buildTimestamp) - if what.flags.interactiveFlag || cfg.GetInteractive() { + if what.config.GetServerMode() { + serverError := what.runServer() + what.rootCmd.Println(serverError) + } else if what.config.GetInteractive() { what.run(what.rootCmd, nil) } } func (what *Threagile) Init(buildTimestamp string) *Threagile { what.buildTimestamp = buildTimestamp - return what.initRoot().initAnalyze().initCreate().initExecute().initExplain().initList().initPrint().initQuit().initServer().initVersion() + return what.initRoot().initAnalyze().initCreate().initExecute().initExplain().initList().initPrint().initQuit().initServer().initVersion().processSystemArgs(what.rootCmd) }