From a3df8e21cfd4e17ccf80313c5b57e7e55d230a88 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 5 Jun 2024 15:50:04 +0200 Subject: [PATCH 01/56] Move portbase into monorepo --- Earthfile | 2 +- assets/icons_default.go | 2 +- base/.gitignore | 8 + base/README.md | 157 ++++ base/api/api_bridge.go | 173 ++++ base/api/auth_wrapper.go | 30 + base/api/authentication.go | 601 +++++++++++++ base/api/authentication_test.go | 194 +++++ base/api/client/api.go | 57 ++ base/api/client/client.go | 240 +++++ base/api/client/const.go | 28 + base/api/client/message.go | 95 ++ base/api/client/websocket.go | 121 +++ base/api/config.go | 91 ++ base/api/database.go | 698 +++++++++++++++ base/api/doc.go | 10 + base/api/endpoints.go | 532 +++++++++++ base/api/endpoints_config.go | 24 + base/api/endpoints_debug.go | 256 ++++++ base/api/endpoints_meta.go | 140 +++ base/api/endpoints_modules.go | 56 ++ base/api/endpoints_test.go | 161 ++++ base/api/enriched-response.go | 68 ++ base/api/main.go | 88 ++ base/api/main_test.go | 56 ++ base/api/modules.go | 49 ++ base/api/request.go | 60 ++ base/api/router.go | 334 +++++++ base/api/testclient/root/index.html | 49 ++ base/api/testclient/serve.go | 11 + base/apprise/notify.go | 167 ++++ base/config/basic_config.go | 113 +++ base/config/database.go | 169 ++++ base/config/doc.go | 2 + base/config/expertise.go | 104 +++ base/config/get-safe.go | 112 +++ base/config/get.go | 174 ++++ base/config/get_test.go | 368 ++++++++ base/config/main.go | 141 +++ base/config/option.go | 418 +++++++++ base/config/persistence.go | 234 +++++ base/config/persistence_test.go | 97 +++ base/config/perspective.go | 133 +++ base/config/registry.go | 106 +++ base/config/registry_test.go | 49 ++ base/config/release.go | 101 +++ base/config/set.go | 235 +++++ base/config/set_test.go | 193 ++++ base/config/validate.go | 239 +++++ base/config/validity.go | 32 + base/container/container.go | 368 ++++++++ base/container/container_test.go | 208 +++++ base/container/doc.go | 26 + base/container/serialization.go | 21 + base/database/accessor/accessor-json-bytes.go | 116 +++ .../database/accessor/accessor-json-string.go | 140 +++ base/database/accessor/accessor-struct.go | 169 ++++ base/database/accessor/accessor.go | 18 + base/database/accessor/accessor_test.go | 291 +++++++ base/database/boilerplate_test.go | 65 ++ base/database/controller.go | 355 ++++++++ base/database/controllers.go | 106 +++ base/database/database.go | 26 + base/database/database_test.go | 303 +++++++ base/database/dbmodule/db.go | 50 ++ base/database/dbmodule/maintenance.go | 31 + base/database/doc.go | 62 ++ base/database/errors.go | 14 + base/database/hook.go | 91 ++ base/database/hookbase.go | 38 + base/database/interface.go | 585 +++++++++++++ base/database/interface_cache.go | 227 +++++ base/database/interface_cache_test.go | 156 ++++ base/database/iterator/iterator.go | 54 ++ base/database/main.go | 85 ++ base/database/maintenance.go | 64 ++ base/database/migration/error.go | 58 ++ base/database/migration/migration.go | 220 +++++ base/database/query/README.md | 55 ++ base/database/query/condition-and.go | 46 + base/database/query/condition-bool.go | 69 ++ base/database/query/condition-error.go | 27 + base/database/query/condition-exists.go | 35 + base/database/query/condition-float.go | 97 +++ base/database/query/condition-int.go | 93 ++ base/database/query/condition-not.go | 36 + base/database/query/condition-or.go | 46 + base/database/query/condition-regex.go | 63 ++ base/database/query/condition-string.go | 62 ++ 
base/database/query/condition-stringslice.go | 69 ++ base/database/query/condition.go | 71 ++ base/database/query/condition_test.go | 86 ++ base/database/query/operators.go | 53 ++ base/database/query/operators_test.go | 11 + base/database/query/parser.go | 350 ++++++++ base/database/query/parser_test.go | 177 ++++ base/database/query/query.go | 170 ++++ base/database/query/query_test.go | 113 +++ base/database/record/base.go | 156 ++++ base/database/record/base_test.go | 13 + base/database/record/key.go | 14 + base/database/record/meta-bench_test.go | 348 ++++++++ base/database/record/meta-gencode.go | 145 +++ base/database/record/meta-gencode_test.go | 35 + base/database/record/meta.colf | 10 + base/database/record/meta.gencode | 8 + base/database/record/meta.go | 129 +++ base/database/record/record.go | 32 + base/database/record/record_test.go | 10 + base/database/record/wrapper.go | 160 ++++ base/database/record/wrapper_test.go | 57 ++ base/database/registry.go | 168 ++++ base/database/storage/badger/badger.go | 231 +++++ base/database/storage/badger/badger_test.go | 148 ++++ base/database/storage/bbolt/bbolt.go | 427 +++++++++ base/database/storage/bbolt/bbolt_test.go | 206 +++++ base/database/storage/errors.go | 8 + base/database/storage/fstree/fstree.go | 302 +++++++ base/database/storage/fstree/fstree_test.go | 6 + base/database/storage/hashmap/map.go | 216 +++++ base/database/storage/hashmap/map_test.go | 145 +++ base/database/storage/injectbase.go | 60 ++ base/database/storage/interface.go | 48 + base/database/storage/sinkhole/sinkhole.go | 111 +++ base/database/storage/storages.go | 47 + base/database/subscription.go | 35 + base/dataroot/root.go | 25 + base/formats/dsd/compression.go | 103 +++ base/formats/dsd/dsd.go | 160 ++++ base/formats/dsd/dsd_test.go | 327 +++++++ base/formats/dsd/format.go | 73 ++ base/formats/dsd/gencode_test.go | 824 ++++++++++++++++++ base/formats/dsd/http.go | 178 ++++ base/formats/dsd/http_test.go | 45 + base/formats/dsd/interfaces.go | 9 + base/formats/dsd/tests.gencode | 23 + base/formats/varint/helpers.go | 48 + base/formats/varint/varint.go | 97 +++ base/formats/varint/varint_test.go | 141 +++ base/info/module/flags.go | 38 + base/info/version.go | 169 ++++ base/log/flags.go | 13 + base/log/formatting.go | 97 +++ base/log/formatting_unix.go | 44 + base/log/formatting_windows.go | 56 ++ base/log/input.go | 219 +++++ base/log/logging.go | 243 ++++++ base/log/logging_test.go | 64 ++ base/log/output.go | 289 ++++++ base/log/trace.go | 280 ++++++ base/log/trace_test.go | 35 + base/metrics/api.go | 158 ++++ base/metrics/config.go | 108 +++ base/metrics/metric.go | 165 ++++ base/metrics/metric_counter.go | 49 ++ base/metrics/metric_counter_fetching.go | 62 ++ base/metrics/metric_export.go | 89 ++ base/metrics/metric_gauge.go | 46 + base/metrics/metric_histogram.go | 41 + base/metrics/metrics_host.go | 263 ++++++ base/metrics/metrics_info.go | 45 + base/metrics/metrics_logs.go | 49 ++ base/metrics/metrics_runtime.go | 98 +++ base/metrics/module.go | 171 ++++ base/metrics/persistence.go | 153 ++++ base/metrics/testdata/.gitignore | 1 + base/metrics/testdata/README.md | 4 + base/metrics/testdata/docker-compose.yml | 36 + base/metrics/testdata/grafana/config.ini | 10 + .../grafana/dashboards/portmaster.yml | 11 + .../grafana/datasources/datasource.yml | 8 + base/notifications/cleaner.go | 51 ++ base/notifications/config.go | 32 + base/notifications/database.go | 239 +++++ base/notifications/doc.go | 26 + base/notifications/module-mirror.go | 115 +++ 
base/notifications/module.go | 66 ++ base/notifications/notification.go | 523 +++++++++++ base/rng/doc.go | 9 + base/rng/entropy.go | 124 +++ base/rng/entropy_test.go | 73 ++ base/rng/fullfeed.go | 43 + base/rng/fullfeed_test.go | 15 + base/rng/get.go | 94 ++ base/rng/get_test.go | 41 + base/rng/osfeeder.go | 35 + base/rng/rng.go | 81 ++ base/rng/rng_test.go | 50 ++ base/rng/test/.gitignore | 4 + base/rng/test/README.md | 279 ++++++ base/rng/test/main.go | 191 ++++ base/rng/tickfeeder.go | 75 ++ base/runtime/module.go | 44 + base/runtime/modules_integration.go | 71 ++ base/runtime/provider.go | 74 ++ base/runtime/registry.go | 335 +++++++ base/runtime/registry_test.go | 157 ++++ base/runtime/singe_record_provider.go | 44 + base/runtime/storage.go | 32 + base/runtime/trace_provider.go | 37 + base/template/module.go | 111 +++ base/template/module_test.go | 54 ++ base/updater/doc.go | 2 + base/updater/export.go | 15 + base/updater/fetch.go | 348 ++++++++ base/updater/file.go | 156 ++++ base/updater/filename.go | 57 ++ base/updater/filename_test.go | 80 ++ base/updater/get.go | 91 ++ base/updater/indexes.go | 109 +++ base/updater/indexes_test.go | 57 ++ base/updater/notifier.go | 33 + base/updater/registry.go | 270 ++++++ base/updater/registry_test.go | 35 + base/updater/resource.go | 582 +++++++++++++ base/updater/resource_test.go | 119 +++ base/updater/signing.go | 49 ++ base/updater/state.go | 180 ++++ base/updater/storage.go | 272 ++++++ base/updater/storage_test.go | 68 ++ base/updater/unpacking.go | 195 +++++ base/updater/updating.go | 359 ++++++++ base/utils/atomic.go | 105 +++ base/utils/broadcastflag.go | 84 ++ base/utils/call_limiter.go | 87 ++ base/utils/call_limiter_test.go | 91 ++ base/utils/debug/debug.go | 148 ++++ base/utils/debug/debug_android.go | 31 + base/utils/debug/debug_default.go | 43 + base/utils/fs.go | 51 ++ base/utils/mimetypes.go | 78 ++ base/utils/onceagain.go | 86 ++ base/utils/onceagain_test.go | 60 ++ base/utils/osdetail/colors_windows.go | 51 ++ base/utils/osdetail/command.go | 51 ++ base/utils/osdetail/dnscache_windows.go | 17 + base/utils/osdetail/errors.go | 12 + base/utils/osdetail/service_windows.go | 112 +++ base/utils/osdetail/shell_windows.go | 49 ++ base/utils/osdetail/svchost_windows.go | 120 +++ base/utils/osdetail/version_windows.go | 99 +++ base/utils/osdetail/version_windows_test.go | 29 + base/utils/renameio/LICENSE | 202 +++++ base/utils/renameio/README.md | 55 ++ base/utils/renameio/doc.go | 7 + base/utils/renameio/example_test.go | 57 ++ base/utils/renameio/symlink_test.go | 41 + base/utils/renameio/tempfile.go | 170 ++++ base/utils/renameio/tempfile_linux_test.go | 115 +++ base/utils/renameio/writefile.go | 26 + base/utils/renameio/writefile_test.go | 46 + base/utils/safe.go | 23 + base/utils/safe_test.go | 29 + base/utils/slices.go | 52 ++ base/utils/slices_test.go | 91 ++ base/utils/stablepool.go | 118 +++ base/utils/stablepool_test.go | 120 +++ base/utils/structure.go | 139 +++ base/utils/structure_test.go | 73 ++ base/utils/uuid.go | 45 + base/utils/uuid_test.go | 71 ++ cmds/hub/build | 2 +- cmds/hub/main.go | 8 +- cmds/notifier/http_api.go | 2 +- cmds/notifier/main.go | 14 +- cmds/notifier/notification.go | 2 +- cmds/notifier/notify.go | 8 +- cmds/notifier/notify_linux.go | 2 +- cmds/notifier/notify_windows.go | 2 +- cmds/notifier/shutdown.go | 4 +- cmds/notifier/spn.go | 6 +- cmds/notifier/subsystems.go | 6 +- cmds/notifier/tray.go | 2 +- cmds/observation-hub/apprise.go | 6 +- cmds/observation-hub/build | 2 +- 
cmds/observation-hub/main.go | 10 +- cmds/observation-hub/observe.go | 8 +- cmds/portmaster-core/build | 4 +- cmds/portmaster-core/main.go | 10 +- cmds/portmaster-start/build | 2 +- cmds/portmaster-start/logs.go | 8 +- cmds/portmaster-start/main.go | 10 +- cmds/portmaster-start/update.go | 4 +- cmds/portmaster-start/verify.go | 4 +- cmds/portmaster-start/version.go | 2 +- cmds/testsuite/db.go | 4 +- cmds/trafficgen/main.go | 2 +- cmds/updatemgr/main.go | 4 +- cmds/updatemgr/purge.go | 2 +- cmds/updatemgr/release.go | 2 +- cmds/winkext-test/main.go | 2 +- service/broadcasts/api.go | 6 +- service/broadcasts/data.go | 2 +- service/broadcasts/install_info.go | 10 +- service/broadcasts/module.go | 4 +- service/broadcasts/notify.go | 12 +- service/broadcasts/state.go | 2 +- service/compat/api.go | 2 +- service/compat/debug_default.go | 2 +- service/compat/debug_linux.go | 2 +- service/compat/debug_windows.go | 2 +- service/compat/module.go | 4 +- service/compat/notify.go | 8 +- service/compat/selfcheck.go | 4 +- service/compat/wfpstate.go | 2 +- service/core/api.go | 14 +- service/core/base/databases.go | 6 +- service/core/base/global.go | 8 +- service/core/base/logs.go | 6 +- service/core/base/module.go | 8 +- service/core/config.go | 4 +- service/core/core.go | 8 +- service/core/os_windows.go | 4 +- service/core/pmtesting/testing.go | 8 +- service/firewall/api.go | 8 +- service/firewall/config.go | 6 +- service/firewall/dns.go | 4 +- .../interception/ebpf/bandwidth/interface.go | 2 +- .../ebpf/connection_listener/worker.go | 2 +- .../firewall/interception/ebpf/exec/exec.go | 2 +- .../interception/interception_default.go | 2 +- .../interception/interception_windows.go | 2 +- .../firewall/interception/introspection.go | 2 +- service/firewall/interception/module.go | 4 +- .../firewall/interception/nfq/conntrack.go | 2 +- service/firewall/interception/nfq/nfq.go | 2 +- service/firewall/interception/nfq/packet.go | 2 +- .../firewall/interception/nfqueue_linux.go | 2 +- .../windowskext/bandwidth_stats.go | 2 +- .../interception/windowskext/handler.go | 2 +- .../firewall/interception/windowskext/kext.go | 2 +- .../interception/windowskext/packet.go | 2 +- .../interception/windowskext/service.go | 2 +- .../interception/windowskext2/handler.go | 2 +- .../interception/windowskext2/kext.go | 2 +- .../interception/windowskext2/packet.go | 2 +- service/firewall/master.go | 2 +- service/firewall/module.go | 8 +- service/firewall/packet_handler.go | 2 +- service/firewall/prompt.go | 4 +- service/firewall/tunnel.go | 2 +- service/intel/block_reason.go | 2 +- service/intel/customlists/config.go | 2 +- service/intel/customlists/lists.go | 4 +- service/intel/customlists/module.go | 4 +- service/intel/entity.go | 2 +- service/intel/filterlists/bloom.go | 4 +- service/intel/filterlists/cache_version.go | 4 +- service/intel/filterlists/database.go | 8 +- service/intel/filterlists/decoder.go | 4 +- service/intel/filterlists/index.go | 10 +- service/intel/filterlists/lookup.go | 4 +- service/intel/filterlists/module.go | 4 +- service/intel/filterlists/record.go | 2 +- service/intel/filterlists/updater.go | 10 +- service/intel/geoip/database.go | 4 +- service/intel/geoip/location.go | 2 +- service/intel/geoip/module.go | 4 +- service/intel/geoip/regions.go | 2 +- service/intel/geoip/regions_test.go | 2 +- service/intel/module.go | 2 +- service/nameserver/config.go | 2 +- service/nameserver/conflict.go | 2 +- service/nameserver/metrics.go | 6 +- service/nameserver/module.go | 8 +- service/nameserver/nameserver.go | 2 
+- service/nameserver/nsutil/nsutil.go | 2 +- service/nameserver/response.go | 2 +- service/netenv/adresses.go | 2 +- service/netenv/api.go | 2 +- service/netenv/dbus_linux.go | 2 +- service/netenv/environment_linux.go | 2 +- service/netenv/environment_windows.go | 4 +- service/netenv/icmp_listener.go | 2 +- service/netenv/location.go | 4 +- service/netenv/main.go | 4 +- service/netenv/network-change.go | 4 +- service/netenv/online-status.go | 4 +- service/netquery/database.go | 6 +- service/netquery/manager.go | 8 +- service/netquery/module_api.go | 16 +- service/netquery/orm/schema_builder.go | 2 +- service/netquery/query_handler.go | 2 +- service/netquery/runtime_query_runner.go | 8 +- service/network/api.go | 8 +- service/network/clean.go | 2 +- service/network/connection.go | 6 +- service/network/database.go | 10 +- service/network/dns.go | 2 +- service/network/iphelper/tables.go | 2 +- service/network/metrics.go | 6 +- service/network/module.go | 4 +- service/network/ports.go | 4 +- service/network/proc/findpid.go | 2 +- service/network/proc/pids_by_user.go | 4 +- service/network/proc/tables.go | 2 +- service/network/state/info.go | 2 +- service/network/state/system_default.go | 2 +- service/network/state/tcp.go | 4 +- service/network/state/udp.go | 4 +- service/process/api.go | 2 +- service/process/config.go | 2 +- service/process/database.go | 4 +- service/process/find.go | 4 +- service/process/module.go | 2 +- service/process/process.go | 4 +- service/process/process_linux.go | 2 +- service/process/profile.go | 2 +- service/process/special.go | 2 +- service/process/tags/appimage_unix.go | 2 +- service/process/tags/svchost_windows.go | 4 +- service/process/tags/winstore_windows.go | 4 +- service/profile/api.go | 6 +- service/profile/binmeta/icon.go | 4 +- service/profile/binmeta/icons.go | 2 +- service/profile/config-update.go | 2 +- service/profile/config.go | 2 +- service/profile/database.go | 10 +- service/profile/fingerprint.go | 2 +- service/profile/framework.go | 2 +- service/profile/get.go | 10 +- service/profile/merge.go | 2 +- service/profile/meta.go | 2 +- service/profile/migrations.go | 8 +- service/profile/module.go | 10 +- service/profile/profile-layered-provider.go | 6 +- service/profile/profile-layered.go | 8 +- service/profile/profile.go | 8 +- service/profile/special.go | 2 +- service/resolver/api.go | 4 +- service/resolver/config.go | 2 +- service/resolver/failing.go | 4 +- service/resolver/ipinfo.go | 4 +- service/resolver/main.go | 8 +- service/resolver/metrics.go | 6 +- service/resolver/namerecord.go | 10 +- service/resolver/resolve.go | 4 +- service/resolver/resolver-env.go | 2 +- service/resolver/resolver-https.go | 2 +- service/resolver/resolver-mdns.go | 2 +- service/resolver/resolver-plain.go | 2 +- service/resolver/resolver-tcp.go | 2 +- service/resolver/resolver.go | 2 +- service/resolver/resolver_test.go | 2 +- service/resolver/resolvers.go | 4 +- service/resolver/reverse.go | 2 +- service/resolver/reverse_test.go | 2 +- service/resolver/rrcache.go | 2 +- service/resolver/scopes.go | 2 +- service/status/module.go | 4 +- service/status/provider.go | 4 +- service/status/records.go | 2 +- service/status/security_level.go | 2 +- service/sync/module.go | 4 +- service/sync/profile.go | 6 +- service/sync/setting_single.go | 6 +- service/sync/settings.go | 4 +- service/sync/util.go | 6 +- service/ui/api.go | 4 +- service/ui/serve.go | 9 +- service/updates/api.go | 6 +- service/updates/config.go | 4 +- service/updates/export.go | 10 +- service/updates/get.go | 2 
+- service/updates/helper/electron.go | 4 +- service/updates/helper/indexes.go | 2 +- service/updates/helper/signing.go | 2 +- service/updates/main.go | 10 +- service/updates/notify.go | 2 +- service/updates/os_integration_linux.go | 6 +- service/updates/restart.go | 4 +- service/updates/state.go | 6 +- service/updates/upgrader.go | 14 +- spn/access/api.go | 6 +- spn/access/client.go | 6 +- spn/access/database.go | 4 +- spn/access/module.go | 6 +- spn/access/notify.go | 4 +- spn/access/op_auth.go | 4 +- spn/access/storage.go | 10 +- spn/access/token/module_test.go | 2 +- spn/access/token/pblind.go | 4 +- spn/access/token/request_test.go | 2 +- spn/access/token/scramble.go | 2 +- spn/access/token/token.go | 2 +- spn/access/token/token_test.go | 2 +- spn/access/zones.go | 2 +- spn/cabin/config-public.go | 4 +- spn/cabin/database.go | 4 +- spn/cabin/identity.go | 6 +- spn/cabin/keys.go | 4 +- spn/cabin/module.go | 2 +- spn/cabin/verification.go | 4 +- spn/captain/api.go | 8 +- spn/captain/bootstrap.go | 4 +- spn/captain/client.go | 4 +- spn/captain/config.go | 2 +- spn/captain/establish.go | 2 +- spn/captain/intel.go | 4 +- spn/captain/module.go | 12 +- spn/captain/navigation.go | 4 +- spn/captain/op_gossip.go | 6 +- spn/captain/op_gossip_query.go | 6 +- spn/captain/op_publish.go | 2 +- spn/captain/piers.go | 2 +- spn/captain/public.go | 8 +- spn/captain/status.go | 8 +- spn/crew/connect.go | 2 +- spn/crew/metrics.go | 4 +- spn/crew/module.go | 2 +- spn/crew/op_connect.go | 6 +- spn/crew/op_ping.go | 6 +- spn/crew/sticky.go | 4 +- spn/docks/bandwidth_test.go | 4 +- spn/docks/controller.go | 2 +- spn/docks/crane.go | 8 +- spn/docks/crane_establish.go | 4 +- spn/docks/crane_init.go | 10 +- spn/docks/crane_terminal.go | 2 +- spn/docks/crane_verify.go | 4 +- spn/docks/cranehooks.go | 2 +- spn/docks/hub_import.go | 2 +- spn/docks/metrics.go | 4 +- spn/docks/module.go | 4 +- spn/docks/op_capacity.go | 6 +- spn/docks/op_expand.go | 2 +- spn/docks/op_latency.go | 8 +- spn/docks/op_sync_state.go | 4 +- spn/docks/op_whoami.go | 4 +- spn/docks/terminal_expansion.go | 2 +- spn/hub/database.go | 8 +- spn/hub/hub.go | 4 +- spn/hub/hub_test.go | 2 +- spn/hub/update.go | 8 +- spn/hub/update_test.go | 2 +- spn/navigator/api.go | 4 +- spn/navigator/api_route.go | 4 +- spn/navigator/database.go | 10 +- spn/navigator/intel.go | 2 +- spn/navigator/map.go | 4 +- spn/navigator/measurements.go | 4 +- spn/navigator/metrics.go | 4 +- spn/navigator/module.go | 6 +- spn/navigator/module_test.go | 2 +- spn/navigator/options.go | 2 +- spn/navigator/pin.go | 2 +- spn/navigator/pin_export.go | 2 +- spn/navigator/region.go | 2 +- spn/navigator/routing-profiles.go | 2 +- spn/navigator/update.go | 14 +- spn/patrol/http.go | 4 +- spn/patrol/module.go | 2 +- spn/ships/http.go | 2 +- spn/ships/http_info.go | 6 +- spn/ships/http_info_test.go | 2 +- spn/ships/http_shared.go | 2 +- spn/ships/launch.go | 2 +- spn/ships/module.go | 2 +- spn/ships/ship.go | 2 +- spn/ships/tcp.go | 2 +- spn/sluice/module.go | 4 +- spn/sluice/sluice.go | 2 +- spn/terminal/control_flow.go | 4 +- spn/terminal/errors.go | 2 +- spn/terminal/init.go | 6 +- spn/terminal/metrics.go | 4 +- spn/terminal/module.go | 4 +- spn/terminal/msg.go | 2 +- spn/terminal/msgtypes.go | 4 +- spn/terminal/operation.go | 6 +- spn/terminal/operation_counter.go | 8 +- spn/terminal/session.go | 2 +- spn/terminal/terminal.go | 8 +- spn/terminal/terminal_test.go | 2 +- spn/terminal/testing.go | 4 +- spn/unit/unit_debug.go | 2 +- 576 files changed, 31442 insertions(+), 665 
deletions(-) create mode 100644 base/.gitignore create mode 100644 base/README.md create mode 100644 base/api/api_bridge.go create mode 100644 base/api/auth_wrapper.go create mode 100644 base/api/authentication.go create mode 100644 base/api/authentication_test.go create mode 100644 base/api/client/api.go create mode 100644 base/api/client/client.go create mode 100644 base/api/client/const.go create mode 100644 base/api/client/message.go create mode 100644 base/api/client/websocket.go create mode 100644 base/api/config.go create mode 100644 base/api/database.go create mode 100644 base/api/doc.go create mode 100644 base/api/endpoints.go create mode 100644 base/api/endpoints_config.go create mode 100644 base/api/endpoints_debug.go create mode 100644 base/api/endpoints_meta.go create mode 100644 base/api/endpoints_modules.go create mode 100644 base/api/endpoints_test.go create mode 100644 base/api/enriched-response.go create mode 100644 base/api/main.go create mode 100644 base/api/main_test.go create mode 100644 base/api/modules.go create mode 100644 base/api/request.go create mode 100644 base/api/router.go create mode 100644 base/api/testclient/root/index.html create mode 100644 base/api/testclient/serve.go create mode 100644 base/apprise/notify.go create mode 100644 base/config/basic_config.go create mode 100644 base/config/database.go create mode 100644 base/config/doc.go create mode 100644 base/config/expertise.go create mode 100644 base/config/get-safe.go create mode 100644 base/config/get.go create mode 100644 base/config/get_test.go create mode 100644 base/config/main.go create mode 100644 base/config/option.go create mode 100644 base/config/persistence.go create mode 100644 base/config/persistence_test.go create mode 100644 base/config/perspective.go create mode 100644 base/config/registry.go create mode 100644 base/config/registry_test.go create mode 100644 base/config/release.go create mode 100644 base/config/set.go create mode 100644 base/config/set_test.go create mode 100644 base/config/validate.go create mode 100644 base/config/validity.go create mode 100644 base/container/container.go create mode 100644 base/container/container_test.go create mode 100644 base/container/doc.go create mode 100644 base/container/serialization.go create mode 100644 base/database/accessor/accessor-json-bytes.go create mode 100644 base/database/accessor/accessor-json-string.go create mode 100644 base/database/accessor/accessor-struct.go create mode 100644 base/database/accessor/accessor.go create mode 100644 base/database/accessor/accessor_test.go create mode 100644 base/database/boilerplate_test.go create mode 100644 base/database/controller.go create mode 100644 base/database/controllers.go create mode 100644 base/database/database.go create mode 100644 base/database/database_test.go create mode 100644 base/database/dbmodule/db.go create mode 100644 base/database/dbmodule/maintenance.go create mode 100644 base/database/doc.go create mode 100644 base/database/errors.go create mode 100644 base/database/hook.go create mode 100644 base/database/hookbase.go create mode 100644 base/database/interface.go create mode 100644 base/database/interface_cache.go create mode 100644 base/database/interface_cache_test.go create mode 100644 base/database/iterator/iterator.go create mode 100644 base/database/main.go create mode 100644 base/database/maintenance.go create mode 100644 base/database/migration/error.go create mode 100644 base/database/migration/migration.go create mode 100644 base/database/query/README.md 
create mode 100644 base/database/query/condition-and.go create mode 100644 base/database/query/condition-bool.go create mode 100644 base/database/query/condition-error.go create mode 100644 base/database/query/condition-exists.go create mode 100644 base/database/query/condition-float.go create mode 100644 base/database/query/condition-int.go create mode 100644 base/database/query/condition-not.go create mode 100644 base/database/query/condition-or.go create mode 100644 base/database/query/condition-regex.go create mode 100644 base/database/query/condition-string.go create mode 100644 base/database/query/condition-stringslice.go create mode 100644 base/database/query/condition.go create mode 100644 base/database/query/condition_test.go create mode 100644 base/database/query/operators.go create mode 100644 base/database/query/operators_test.go create mode 100644 base/database/query/parser.go create mode 100644 base/database/query/parser_test.go create mode 100644 base/database/query/query.go create mode 100644 base/database/query/query_test.go create mode 100644 base/database/record/base.go create mode 100644 base/database/record/base_test.go create mode 100644 base/database/record/key.go create mode 100644 base/database/record/meta-bench_test.go create mode 100644 base/database/record/meta-gencode.go create mode 100644 base/database/record/meta-gencode_test.go create mode 100644 base/database/record/meta.colf create mode 100644 base/database/record/meta.gencode create mode 100644 base/database/record/meta.go create mode 100644 base/database/record/record.go create mode 100644 base/database/record/record_test.go create mode 100644 base/database/record/wrapper.go create mode 100644 base/database/record/wrapper_test.go create mode 100644 base/database/registry.go create mode 100644 base/database/storage/badger/badger.go create mode 100644 base/database/storage/badger/badger_test.go create mode 100644 base/database/storage/bbolt/bbolt.go create mode 100644 base/database/storage/bbolt/bbolt_test.go create mode 100644 base/database/storage/errors.go create mode 100644 base/database/storage/fstree/fstree.go create mode 100644 base/database/storage/fstree/fstree_test.go create mode 100644 base/database/storage/hashmap/map.go create mode 100644 base/database/storage/hashmap/map_test.go create mode 100644 base/database/storage/injectbase.go create mode 100644 base/database/storage/interface.go create mode 100644 base/database/storage/sinkhole/sinkhole.go create mode 100644 base/database/storage/storages.go create mode 100644 base/database/subscription.go create mode 100644 base/dataroot/root.go create mode 100644 base/formats/dsd/compression.go create mode 100644 base/formats/dsd/dsd.go create mode 100644 base/formats/dsd/dsd_test.go create mode 100644 base/formats/dsd/format.go create mode 100644 base/formats/dsd/gencode_test.go create mode 100644 base/formats/dsd/http.go create mode 100644 base/formats/dsd/http_test.go create mode 100644 base/formats/dsd/interfaces.go create mode 100644 base/formats/dsd/tests.gencode create mode 100644 base/formats/varint/helpers.go create mode 100644 base/formats/varint/varint.go create mode 100644 base/formats/varint/varint_test.go create mode 100644 base/info/module/flags.go create mode 100644 base/info/version.go create mode 100644 base/log/flags.go create mode 100644 base/log/formatting.go create mode 100644 base/log/formatting_unix.go create mode 100644 base/log/formatting_windows.go create mode 100644 base/log/input.go create mode 100644 base/log/logging.go 
create mode 100644 base/log/logging_test.go create mode 100644 base/log/output.go create mode 100644 base/log/trace.go create mode 100644 base/log/trace_test.go create mode 100644 base/metrics/api.go create mode 100644 base/metrics/config.go create mode 100644 base/metrics/metric.go create mode 100644 base/metrics/metric_counter.go create mode 100644 base/metrics/metric_counter_fetching.go create mode 100644 base/metrics/metric_export.go create mode 100644 base/metrics/metric_gauge.go create mode 100644 base/metrics/metric_histogram.go create mode 100644 base/metrics/metrics_host.go create mode 100644 base/metrics/metrics_info.go create mode 100644 base/metrics/metrics_logs.go create mode 100644 base/metrics/metrics_runtime.go create mode 100644 base/metrics/module.go create mode 100644 base/metrics/persistence.go create mode 100644 base/metrics/testdata/.gitignore create mode 100644 base/metrics/testdata/README.md create mode 100644 base/metrics/testdata/docker-compose.yml create mode 100644 base/metrics/testdata/grafana/config.ini create mode 100644 base/metrics/testdata/grafana/dashboards/portmaster.yml create mode 100644 base/metrics/testdata/grafana/datasources/datasource.yml create mode 100644 base/notifications/cleaner.go create mode 100644 base/notifications/config.go create mode 100644 base/notifications/database.go create mode 100644 base/notifications/doc.go create mode 100644 base/notifications/module-mirror.go create mode 100644 base/notifications/module.go create mode 100644 base/notifications/notification.go create mode 100644 base/rng/doc.go create mode 100644 base/rng/entropy.go create mode 100644 base/rng/entropy_test.go create mode 100644 base/rng/fullfeed.go create mode 100644 base/rng/fullfeed_test.go create mode 100644 base/rng/get.go create mode 100644 base/rng/get_test.go create mode 100644 base/rng/osfeeder.go create mode 100644 base/rng/rng.go create mode 100644 base/rng/rng_test.go create mode 100644 base/rng/test/.gitignore create mode 100644 base/rng/test/README.md create mode 100644 base/rng/test/main.go create mode 100644 base/rng/tickfeeder.go create mode 100644 base/runtime/module.go create mode 100644 base/runtime/modules_integration.go create mode 100644 base/runtime/provider.go create mode 100644 base/runtime/registry.go create mode 100644 base/runtime/registry_test.go create mode 100644 base/runtime/singe_record_provider.go create mode 100644 base/runtime/storage.go create mode 100644 base/runtime/trace_provider.go create mode 100644 base/template/module.go create mode 100644 base/template/module_test.go create mode 100644 base/updater/doc.go create mode 100644 base/updater/export.go create mode 100644 base/updater/fetch.go create mode 100644 base/updater/file.go create mode 100644 base/updater/filename.go create mode 100644 base/updater/filename_test.go create mode 100644 base/updater/get.go create mode 100644 base/updater/indexes.go create mode 100644 base/updater/indexes_test.go create mode 100644 base/updater/notifier.go create mode 100644 base/updater/registry.go create mode 100644 base/updater/registry_test.go create mode 100644 base/updater/resource.go create mode 100644 base/updater/resource_test.go create mode 100644 base/updater/signing.go create mode 100644 base/updater/state.go create mode 100644 base/updater/storage.go create mode 100644 base/updater/storage_test.go create mode 100644 base/updater/unpacking.go create mode 100644 base/updater/updating.go create mode 100644 base/utils/atomic.go create mode 100644 base/utils/broadcastflag.go 
create mode 100644 base/utils/call_limiter.go create mode 100644 base/utils/call_limiter_test.go create mode 100644 base/utils/debug/debug.go create mode 100644 base/utils/debug/debug_android.go create mode 100644 base/utils/debug/debug_default.go create mode 100644 base/utils/fs.go create mode 100644 base/utils/mimetypes.go create mode 100644 base/utils/onceagain.go create mode 100644 base/utils/onceagain_test.go create mode 100644 base/utils/osdetail/colors_windows.go create mode 100644 base/utils/osdetail/command.go create mode 100644 base/utils/osdetail/dnscache_windows.go create mode 100644 base/utils/osdetail/errors.go create mode 100644 base/utils/osdetail/service_windows.go create mode 100644 base/utils/osdetail/shell_windows.go create mode 100644 base/utils/osdetail/svchost_windows.go create mode 100644 base/utils/osdetail/version_windows.go create mode 100644 base/utils/osdetail/version_windows_test.go create mode 100644 base/utils/renameio/LICENSE create mode 100644 base/utils/renameio/README.md create mode 100644 base/utils/renameio/doc.go create mode 100644 base/utils/renameio/example_test.go create mode 100644 base/utils/renameio/symlink_test.go create mode 100644 base/utils/renameio/tempfile.go create mode 100644 base/utils/renameio/tempfile_linux_test.go create mode 100644 base/utils/renameio/writefile.go create mode 100644 base/utils/renameio/writefile_test.go create mode 100644 base/utils/safe.go create mode 100644 base/utils/safe_test.go create mode 100644 base/utils/slices.go create mode 100644 base/utils/slices_test.go create mode 100644 base/utils/stablepool.go create mode 100644 base/utils/stablepool_test.go create mode 100644 base/utils/structure.go create mode 100644 base/utils/structure_test.go create mode 100644 base/utils/uuid.go create mode 100644 base/utils/uuid_test.go diff --git a/Earthfile b/Earthfile index 9e9de59a1..ca132a624 100644 --- a/Earthfile +++ b/Earthfile @@ -125,7 +125,7 @@ go-build: # Build all go binaries from the specified in CMDS FOR bin IN $CMDS - RUN --no-cache go build -ldflags="-X github.com/safing/portbase/info.version=${VERSION} -X github.com/safing/portbase/info.buildSource=${SOURCE} -X github.com/safing/portbase/info.buildTime=${BUILD_TIME}" -o "/tmp/build/" ./cmds/${bin} + RUN --no-cache go build -ldflags="-X github.com/safing/portmaster/base/info.version=${VERSION} -X github.com/safing/portmaster/base/info.buildSource=${SOURCE} -X github.com/safing/portmaster/base/info.buildTime=${BUILD_TIME}" -o "/tmp/build/" ./cmds/${bin} END DO +GO_ARCH_STRING --goos="${GOOS}" --goarch="${GOARCH}" --goarm="${GOARM}" diff --git a/assets/icons_default.go b/assets/icons_default.go index 2530f3091..a597eceba 100644 --- a/assets/icons_default.go +++ b/assets/icons_default.go @@ -11,7 +11,7 @@ import ( "golang.org/x/image/draw" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) // Colored Icon IDs. diff --git a/base/.gitignore b/base/.gitignore new file mode 100644 index 000000000..764174cff --- /dev/null +++ b/base/.gitignore @@ -0,0 +1,8 @@ +portbase +apitest +misc + +go.mod.* +vendor +go.work +go.work.sum diff --git a/base/README.md b/base/README.md new file mode 100644 index 000000000..4791e9b2d --- /dev/null +++ b/base/README.md @@ -0,0 +1,157 @@ +> **Check out our main project at [safing/portmaster](https://github.com/safing/portmaster)** + +# Portbase + +Portbase helps you quickly take off with your project. It gives you all the basic needs you would have for a service (_not_ tool!). 
+Here is what is included:
+
+- `log`: really fast and beautiful logging
+- `modules`: a multi-stage, dependency-aware boot process for your software, also manages tasks
+- `config`: simple, live updating and extremely fast configuration storage
+- `info`: easily tag your builds with versions, commit hashes, and so on
+- `formats`: some handy data encoding libs
+- `rng`: a feedable CSPRNG for great randomness
+- `database`: intelligent and syncable database with hooks and easy integration with structs, uses buckets with different backends
+- `api`: a websocket interface to the database that can be extended with custom HTTP handlers
+
+Before you continue, a word about this project. It was created to hold the base code for both Portmaster and Gate17. This is also what it will be developed for. If you have a great idea on how to improve portbase, please, by all means, raise an issue and tell us about it, but please also don't be surprised or offended if we ask you to create your own fork to do what you need. Portbase isn't for everyone, it's quite specific to our needs, but we decided to make it easily available to others.
+
+Portbase is actively maintained; please raise issues.
+
+## log
+
+The main goal of this logging package is to be as fast as possible. Logs are sent to a channel with only minimal processing beforehand, so that the service can continue with the important work and write the logs later.
+
+Second is beauty, both in what information is provided and in how it is presented.
+
+You can use flags to change the log level on a per-source-file basis.
+
+## modules requires `log`
+
+Packages may register themselves as modules to take part in the multi-stage boot and coordinated shutdown.
+
+Registering only requires a name/key and the `prep()`, `start()` and `stop()` functions.
+
+This is how modules are booted:
+
+- `init()` available: ~~flags~~, ~~config~~, ~~logging~~, ~~dependencies~~
+  - register flags (with the stdlib `flag` library)
+  - register module
+- `module.prep()` available: flags, ~~config~~, ~~logging~~, ~~dependencies~~
+  - react to flags
+  - register config variables
+  - if an error occurs, return it
+  - return ErrCleanExit for a clean, successful exit (e.g. you only printed a version)
+- `module.start()` available: flags, config, logging, dependencies
+  - start tasks and workers
+  - do not log errors while starting, but return them
+- `module.stop()` available: flags, config, logging, dependencies
+  - stop all work (i.e. goroutines)
+  - do not log errors while stopping, but return them
+
+You can start tasks and workers from your module that are then integrated into the module system and will allow for insights and better control of them in the future.
+
+## config requires `log`
+
+The config package stores the configuration in JSON strings. This may sound a bit weird, but it's very practical.
+
+There are three layers of configuration, in order of priority: user configuration, default configuration, and the fallback values supplied when registering a config variable.
+
+When using config variables, you get a function that checks every time whether your config variable is still up to date. If it did not change, this is _extremely_ fast. If it did change, it fetches the current value, which takes a short while, but this does not happen often.
+
+    // This is how you would get a string config variable function.
+    myVar := GetAsString("my_config_var", "default")
+    // You then use myVar() directly every time, except when you must guarantee the same value between two calls
+    if myVar() != "default" {
+      log.Infof("my_config_var is set to %s", myVar())
+    }
+    // no error handling needed! :)
+
+WARNING: While these config variable functions are _extremely_ fast, they are _NOT_ thread/goroutine safe! (Use the `Concurrent` wrapper for that!)
+
+## info
+
+Info provides an easy way to store your version and build information within the binary. If you use the `build` script to build the program, it will automatically set build information so that you can easily find out when and from which commit a binary was built.
+
+The `build` script extracts information from the host and the git repo and then calls `go build` with some additional arguments.
+
+## formats/varint
+
+This is just a convenience wrapper around `encoding/binary`, because we use varints a lot.
+
+## formats/dsd requires `formats/varint`
+
+DSD stands for dynamically structured data. In short, this is a generic packer that reacts to the supplied data type:
+
+- structs are usually JSON-encoded
+- []bytes and strings stay the same
+
+This makes it easier and more efficient to store different data types in a key/value data store.
+
+## rng requires `log`, `config`
+
+This package provides a CSPRNG based on the [Fortuna](https://en.wikipedia.org/wiki/Fortuna_(PRNG)) CSPRNG, devised by Bruce Schneier and Niels Ferguson, implemented by Jochen Voss and published [on GitHub](https://github.com/seehuhn/fortuna).
+
+Only the Generator is used from the `fortuna` package. The feeding system implemented here is configurable and designed with efficiency in mind.
+
+While you can feed the RNG yourself, it has two feeders by default:
+- It starts with a seed from `crypto/rand` and periodically reseeds from there
+- A really simple tickfeeder which extracts entropy from the internal Go scheduler using goroutines and is meant to be used under load.
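+
+A minimal consumer-side sketch (`rng.Bytes` and `log.Infof` are the same helpers used elsewhere in this patch; the key size and log message are purely illustrative):
+
+    // Request 16 bytes from the shared CSPRNG.
+    key, err := rng.Bytes(16)
+    if err != nil {
+      return err // do not fall back to weak randomness
+    }
+    log.Infof("generated key: %x", key)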
+
+## database requires `log`
+_introduction to be written_
+
+## api requires `log`, `database`, `config`
+_introduction to be written_
+
+## The main program
+
+If you build everything with modules, your main program should be similar to this - just use a blank import for the modules you need:
+
+    import (
+      "os"
+      "os/signal"
+      "syscall"
+
+      "github.com/safing/portmaster/base/info"
+      "github.com/safing/portmaster/base/log"
+      "github.com/safing/portmaster/base/modules"
+
+      // include packages here
+      _ "path/to/my/custom/module"
+    )
+
+    func main() {
+
+      // Set Info
+      info.Set("MySoftware", "1.0.0")
+
+      // Start
+      err := modules.Start()
+      if err != nil {
+        if err == modules.ErrCleanExit {
+          os.Exit(0)
+        } else {
+          os.Exit(1)
+        }
+      }
+
+      // Shutdown
+      // catch interrupt for clean shutdown
+      signalCh := make(chan os.Signal, 1)
+      signal.Notify(
+        signalCh,
+        os.Interrupt,
+        syscall.SIGHUP,
+        syscall.SIGINT,
+        syscall.SIGTERM,
+        syscall.SIGQUIT,
+      )
+      select {
+      case <-signalCh:
+        log.Warning("main: program was interrupted")
+        modules.Shutdown()
+      case <-modules.ShuttingDown():
+      }
+
+    }
diff --git a/base/api/api_bridge.go b/base/api/api_bridge.go
new file mode 100644
index 000000000..45291a36e
--- /dev/null
+++ b/base/api/api_bridge.go
@@ -0,0 +1,173 @@
+package api
+
+import (
+	"bytes"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"path"
+	"strings"
+	"sync"
+
+	"github.com/safing/portmaster/base/database"
+	"github.com/safing/portmaster/base/database/record"
+	"github.com/safing/portmaster/base/database/storage"
+)
+
+const (
+	endpointBridgeRemoteAddress = "websocket-bridge"
+	apiDatabaseName             = "api"
+)
+
+func registerEndpointBridgeDB() error {
+	if _, err := database.Register(&database.Database{
+		Name:        apiDatabaseName,
+		Description: "API Bridge",
+		StorageType: "injected",
+	}); err != nil {
+		return err
+	}
+
+	_, err := database.InjectDatabase("api", &endpointBridgeStorage{})
+	return err
+}
+
+type endpointBridgeStorage struct {
+	storage.InjectBase
+}
+
+// EndpointBridgeRequest holds a bridged API request.
+type EndpointBridgeRequest struct {
+	record.Base
+	sync.Mutex
+
+	Method   string
+	Path     string
+	Query    map[string]string
+	Data     []byte
+	MimeType string
+}
+
+// EndpointBridgeResponse holds the response to a bridged API request.
+type EndpointBridgeResponse struct {
+	record.Base
+	sync.Mutex
+
+	MimeType string
+	Body     string
+}
+
+// Get returns a database record.
+func (ebs *endpointBridgeStorage) Get(key string) (record.Record, error) {
+	if key == "" {
+		return nil, database.ErrNotFound
+	}
+
+	return callAPI(&EndpointBridgeRequest{
+		Method: http.MethodGet,
+		Path:   key,
+	})
+}
+
+// GetMeta returns the metadata of a database record.
+func (ebs *endpointBridgeStorage) GetMeta(key string) (*record.Meta, error) {
+	// This interface is an API, always return a fresh copy.
+	m := &record.Meta{}
+	m.Update()
+	return m, nil
+}
+
+// Put stores a record in the database.
+func (ebs *endpointBridgeStorage) Put(r record.Record) (record.Record, error) {
+	if r.DatabaseKey() == "" {
+		return nil, database.ErrNotFound
+	}
+
+	// Prepare data.
+	var ebr *EndpointBridgeRequest
+	if r.IsWrapped() {
+		// Only allocate a new struct if we need it.
+		ebr = &EndpointBridgeRequest{}
+		err := record.Unwrap(r, ebr)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		var ok bool
+		ebr, ok = r.(*EndpointBridgeRequest)
+		if !ok {
+			return nil, fmt.Errorf("record not of type *EndpointBridgeRequest, but %T", r)
+		}
+	}
+
+	// Override the path with the database key so a client-supplied Path cannot address anything else.
+ ebr.Path = r.DatabaseKey() + return callAPI(ebr) +} + +// ReadOnly returns whether the database is read only. +func (ebs *endpointBridgeStorage) ReadOnly() bool { + return false +} + +func callAPI(ebr *EndpointBridgeRequest) (record.Record, error) { + // Add API prefix to path. + requestURL := path.Join(apiV1Path, ebr.Path) + // Check if path is correct. (Defense in depth) + if !strings.HasPrefix(requestURL, apiV1Path) { + return nil, fmt.Errorf("bridged request for %q violates scope", ebr.Path) + } + + // Apply default Method. + if ebr.Method == "" { + if len(ebr.Data) > 0 { + ebr.Method = http.MethodPost + } else { + ebr.Method = http.MethodGet + } + } + + // Build URL. + u, err := url.ParseRequestURI(requestURL) + if err != nil { + return nil, fmt.Errorf("failed to build bridged request url: %w", err) + } + // Build query values. + if ebr.Query != nil && len(ebr.Query) > 0 { + query := url.Values{} + for k, v := range ebr.Query { + query.Set(k, v) + } + u.RawQuery = query.Encode() + } + + // Create request and response objects. + r := httptest.NewRequest(ebr.Method, u.String(), bytes.NewBuffer(ebr.Data)) + r.RemoteAddr = endpointBridgeRemoteAddress + if ebr.MimeType != "" { + r.Header.Set("Content-Type", ebr.MimeType) + } + w := httptest.NewRecorder() + // Let the API handle the request. + server.Handler.ServeHTTP(w, r) + switch w.Code { + case 200: + // Everything okay, continue. + case 500: + // A Go error was returned internally. + // We can safely return this as an error. + return nil, fmt.Errorf("bridged api call failed: %s", w.Body.String()) + default: + return nil, fmt.Errorf("bridged api call returned unexpected error code %d", w.Code) + } + + response := &EndpointBridgeResponse{ + MimeType: w.Header().Get("Content-Type"), + Body: w.Body.String(), + } + response.SetKey(apiDatabaseName + ":" + ebr.Path) + response.UpdateMeta() + + return response, nil +} diff --git a/base/api/auth_wrapper.go b/base/api/auth_wrapper.go new file mode 100644 index 000000000..f836b3ccc --- /dev/null +++ b/base/api/auth_wrapper.go @@ -0,0 +1,30 @@ +package api + +import "net/http" + +// WrapInAuthHandler wraps a simple http.HandlerFunc into a handler that +// exposes the required API permissions for this handler. +func WrapInAuthHandler(fn http.HandlerFunc, read, write Permission) http.Handler { + return &wrappedAuthenticatedHandler{ + HandlerFunc: fn, + read: read, + write: write, + } +} + +type wrappedAuthenticatedHandler struct { + http.HandlerFunc + + read Permission + write Permission +} + +// ReadPermission returns the read permission for the handler. +func (wah *wrappedAuthenticatedHandler) ReadPermission(r *http.Request) Permission { + return wah.read +} + +// WritePermission returns the write permission for the handler. 
+func (wah *wrappedAuthenticatedHandler) WritePermission(r *http.Request) Permission {
+	return wah.write
+}
diff --git a/base/api/authentication.go b/base/api/authentication.go
new file mode 100644
index 000000000..a43512d1c
--- /dev/null
+++ b/base/api/authentication.go
@@ -0,0 +1,601 @@
+package api
+
+import (
+	"context"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/tevino/abool"
+
+	"github.com/safing/portmaster/base/config"
+	"github.com/safing/portmaster/base/log"
+	"github.com/safing/portmaster/base/modules"
+	"github.com/safing/portmaster/base/rng"
+)
+
+const (
+	sessionCookieName = "Portmaster-API-Token"
+	sessionCookieTTL  = 5 * time.Minute
+)
+
+var (
+	apiKeys     = make(map[string]*AuthToken)
+	apiKeysLock sync.Mutex
+
+	authFnSet = abool.New()
+	authFn    AuthenticatorFunc
+
+	sessions     = make(map[string]*session)
+	sessionsLock sync.Mutex
+
+	// ErrAPIAccessDeniedMessage should be wrapped by errors returned by
+	// AuthenticatorFunc in order to signify a blocked request, including an
+	// error message for the user. This is an empty message on purpose, so that
+	// the function can define the full text of the error shown to the user.
+	ErrAPIAccessDeniedMessage = errors.New("")
+)
+
+// Permission defines the permission of an API request.
+type Permission int8
+
+const (
+	// NotFound declares that the operation does not exist.
+	NotFound Permission = -2
+
+	// Dynamic declares that the operation requires permission to be processed,
+	// but anyone can execute the operation, as it reacts to permissions itself.
+	Dynamic Permission = -1
+
+	// NotSupported declares that the operation is not supported.
+	NotSupported Permission = 0
+
+	// PermitAnyone declares that anyone can execute the operation without any
+	// authentication.
+	PermitAnyone Permission = 1
+
+	// PermitUser declares that the operation may be executed by authenticated
+	// third party applications that are categorized as representing a simple
+	// user and are limited in access.
+	PermitUser Permission = 2
+
+	// PermitAdmin declares that the operation may be executed by authenticated
+	// third party applications that are categorized as representing an
+	// administrator and have broad access.
+	PermitAdmin Permission = 3
+
+	// PermitSelf declares that the operation may only be executed by the
+	// software itself and its own (first party) components.
+	PermitSelf Permission = 4
+)
+
+// AuthenticatorFunc is a function that can be set as the authenticator for the
+// API endpoint. If none is set, all requests will have full access.
+// The returned AuthToken represents the permissions that the request has.
+type AuthenticatorFunc func(r *http.Request, s *http.Server) (*AuthToken, error)
+
+// AuthToken represents either a set of required or granted permissions.
+// All attributes must be set when the struct is built and must not be changed
+// later. Functions may be called at any time.
+// The Write permission implicitly also includes reading.
+type AuthToken struct {
+	Read       Permission
+	Write      Permission
+	ValidUntil *time.Time
+}
+
+type session struct {
+	sync.Mutex
+
+	token      *AuthToken
+	validUntil time.Time
+}
+
+// Expired returns whether the session has expired.
+func (sess *session) Expired() bool {
+	sess.Lock()
+	defer sess.Unlock()
+
+	return time.Now().After(sess.validUntil)
+}
+
+// Refresh refreshes the validity of the session with the given TTL.
+func (sess *session) Refresh(ttl time.Duration) { + sess.Lock() + defer sess.Unlock() + + sess.validUntil = time.Now().Add(ttl) +} + +// AuthenticatedHandler defines the handler interface to specify custom +// permission for an API handler. The returned permission is the required +// permission for the request to proceed. +type AuthenticatedHandler interface { + ReadPermission(*http.Request) Permission + WritePermission(*http.Request) Permission +} + +// SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted. +func SetAuthenticator(fn AuthenticatorFunc) error { + if module.Online() { + return ErrAuthenticationImmutable + } + + if !authFnSet.SetToIf(false, true) { + return ErrAuthenticationAlreadySet + } + + authFn = fn + return nil +} + +func authenticateRequest(w http.ResponseWriter, r *http.Request, targetHandler http.Handler, readMethod bool) *AuthToken { + tracer := log.Tracer(r.Context()) + + // Get required permission for target handler. + requiredPermission := PermitSelf + if authdHandler, ok := targetHandler.(AuthenticatedHandler); ok { + if readMethod { + requiredPermission = authdHandler.ReadPermission(r) + } else { + requiredPermission = authdHandler.WritePermission(r) + } + } + + // Check if we need to do any authentication at all. + switch requiredPermission { //nolint:exhaustive + case NotFound: + // Not found. + tracer.Debug("api: no API endpoint registered for this path") + http.Error(w, "Not found.", http.StatusNotFound) + return nil + case NotSupported: + // A read or write permission can be marked as not supported. + tracer.Trace("api: authenticated handler reported: not supported") + http.Error(w, "Method not allowed.", http.StatusMethodNotAllowed) + return nil + case PermitAnyone: + // Don't process permissions, as we don't need them. + tracer.Tracef("api: granted %s access to public handler", r.RemoteAddr) + return &AuthToken{ + Read: PermitAnyone, + Write: PermitAnyone, + } + case Dynamic: + // Continue processing permissions, but treat as PermitAnyone. + requiredPermission = PermitAnyone + } + + // The required permission must match the request permission values after + // handling the specials. + if requiredPermission < PermitAnyone || requiredPermission > PermitSelf { + tracer.Warningf( + "api: handler returned invalid permission: %s (%d)", + requiredPermission, + requiredPermission, + ) + http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError) + return nil + } + + // Authenticate request. + token, handled := checkAuth(w, r, requiredPermission > PermitAnyone) + switch { + case handled: + return nil + case token == nil: + // Use default permissions. + token = &AuthToken{ + Read: PermitAnyone, + Write: PermitAnyone, + } + } + + // Get effective permission for request. + var requestPermission Permission + if readMethod { + requestPermission = token.Read + } else { + requestPermission = token.Write + } + + // Check for valid request permission. + if requestPermission < PermitAnyone || requestPermission > PermitSelf { + tracer.Warningf( + "api: authenticator returned invalid permission: %s (%d)", + requestPermission, + requestPermission, + ) + http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError) + return nil + } + + // Check permission. + if requestPermission < requiredPermission { + // If the token is strictly public, return an authentication request. 
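+		// A strictly public token most likely means that no (valid) credentials
+		// were presented at all, so ask the client to authenticate via the
+		// WWW-Authenticate header instead of rejecting the request outright.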
+		if token.Read == PermitAnyone && token.Write == PermitAnyone {
+			w.Header().Set(
+				"WWW-Authenticate",
+				`Bearer realm="Portmaster API" domain="/"`,
+			)
+			http.Error(w, "Authorization required.", http.StatusUnauthorized)
+			return nil
+		}
+
+		// Otherwise just inform of insufficient permissions.
+		http.Error(w, "Insufficient permissions.", http.StatusForbidden)
+		return nil
+	}
+
+	tracer.Tracef("api: granted %s access to protected handler", r.RemoteAddr)
+
+	// Make a copy of the AuthToken in order to mitigate the handler poisoning the
+	// token, as changes would apply to future requests.
+	return &AuthToken{
+		Read:  token.Read,
+		Write: token.Write,
+	}
+}
+
+func checkAuth(w http.ResponseWriter, r *http.Request, authRequired bool) (token *AuthToken, handled bool) {
+	// Return highest possible permissions in dev mode.
+	if devMode() {
+		return &AuthToken{
+			Read:  PermitSelf,
+			Write: PermitSelf,
+		}, false
+	}
+
+	// Database Bridge Access.
+	if r.RemoteAddr == endpointBridgeRemoteAddress {
+		return &AuthToken{
+			Read:  dbCompatibilityPermission,
+			Write: dbCompatibilityPermission,
+		}, false
+	}
+
+	// Check for valid API key.
+	token = checkAPIKey(r)
+	if token != nil {
+		return token, false
+	}
+
+	// Check for valid session cookie.
+	token = checkSessionCookie(r)
+	if token != nil {
+		return token, false
+	}
+
+	// Check if an external authentication method is available.
+	if !authFnSet.IsSet() {
+		return nil, false
+	}
+
+	// Authenticate externally.
+	token, err := authFn(r, server)
+	if err != nil {
+		// Check if the authentication process failed internally.
+		if !errors.Is(err, ErrAPIAccessDeniedMessage) {
+			log.Tracer(r.Context()).Errorf("api: authenticator failed: %s", err)
+			http.Error(w, "Internal server error during authentication.", http.StatusInternalServerError)
+			return nil, true
+		}
+
+		// Return authentication failure message if authentication is required.
+		if authRequired {
+			log.Tracer(r.Context()).Warningf("api: denying api access from %s", r.RemoteAddr)
+			http.Error(w, err.Error(), http.StatusForbidden)
+			return nil, true
+		}
+
+		return nil, false
+	}
+
+	// Abort if no token is returned.
+	if token == nil {
+		return nil, false
+	}
+
+	// Create session cookie for authenticated request.
+	err = createSession(w, r, token)
+	if err != nil {
+		log.Tracer(r.Context()).Warningf("api: failed to create session: %s", err)
+	}
+	return token, false
+}
+
+func checkAPIKey(r *http.Request) *AuthToken {
+	// Get API key from request.
+	key := r.Header.Get("Authorization")
+	if key == "" {
+		return nil
+	}
+
+	// Parse API key.
+	switch {
+	case strings.HasPrefix(key, "Bearer "):
+		key = strings.TrimPrefix(key, "Bearer ")
+	case strings.HasPrefix(key, "Basic "):
+		user, pass, _ := r.BasicAuth()
+		key = user + pass
+	default:
+		log.Tracer(r.Context()).Tracef(
+			"api: provided api key type %s is unsupported", strings.Split(key, " ")[0],
+		)
+		return nil
+	}
+
+	apiKeysLock.Lock()
+	defer apiKeysLock.Unlock()
+
+	// Check if the provided API key exists.
+	token, ok := apiKeys[key]
+	if !ok {
+		log.Tracer(r.Context()).Tracef(
+			"api: provided api key %s... is unknown", key[:4],
+		)
+		return nil
+	}
+
+	// Abort if the token is expired.
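+	// Expired keys are also pruned from the configuration by updateAPIKeys;
+	// this runtime check additionally covers keys that expire between config updates.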
+ if token.ValidUntil != nil && time.Now().After(*token.ValidUntil) { + log.Tracer(r.Context()).Warningf("api: denying api access from %s using expired token", r.RemoteAddr) + return nil + } + + return token +} + +func updateAPIKeys(_ context.Context, _ interface{}) error { + apiKeysLock.Lock() + defer apiKeysLock.Unlock() + + log.Debug("api: importing possibly updated API keys from config") + + // Delete current keys. + for k := range apiKeys { + delete(apiKeys, k) + } + + // whether or not we found expired API keys that should be removed + // from the setting + hasExpiredKeys := false + + // a list of valid API keys. Used when hasExpiredKeys is set to true. + // in that case we'll update the setting to only contain validAPIKeys + validAPIKeys := []string{} + + // Parse new keys. + for _, key := range configuredAPIKeys() { + u, err := url.Parse(key) + if err != nil { + log.Errorf("api: failed to parse configured API key %s: %s", key, err) + + continue + } + + if u.Path == "" { + log.Errorf("api: malformed API key %s: missing path section", key) + + continue + } + + // Create token with default permissions. + token := &AuthToken{ + Read: PermitAnyone, + Write: PermitAnyone, + } + + // Update with configured permissions. + q := u.Query() + // Parse read permission. + readPermission, err := parseAPIPermission(q.Get("read")) + if err != nil { + log.Errorf("api: invalid API key %s: %s", key, err) + continue + } + token.Read = readPermission + // Parse write permission. + writePermission, err := parseAPIPermission(q.Get("write")) + if err != nil { + log.Errorf("api: invalid API key %s: %s", key, err) + continue + } + token.Write = writePermission + + expireStr := q.Get("expires") + if expireStr != "" { + validUntil, err := time.Parse(time.RFC3339, expireStr) + if err != nil { + log.Errorf("api: invalid API key %s: %s", key, err) + continue + } + + // continue to the next token if this one is already invalid + if time.Now().After(validUntil) { + // mark the key as expired so we'll remove it from the setting afterwards + hasExpiredKeys = true + + continue + } + + token.ValidUntil = &validUntil + } + + // Save token. + apiKeys[u.Path] = token + validAPIKeys = append(validAPIKeys, key) + } + + if hasExpiredKeys { + module.StartLowPriorityMicroTask("api key cleanup", 0, func(ctx context.Context) error { + if err := config.SetConfigOption(CfgAPIKeys, validAPIKeys); err != nil { + log.Errorf("api: failed to remove expired API keys: %s", err) + } else { + log.Infof("api: removed expired API keys from %s", CfgAPIKeys) + } + + return nil + }) + } + + return nil +} + +func checkSessionCookie(r *http.Request) *AuthToken { + // Get session cookie from request. + c, err := r.Cookie(sessionCookieName) + if err != nil { + return nil + } + + // Check if session cookie is registered. + sessionsLock.Lock() + sess, ok := sessions[c.Value] + sessionsLock.Unlock() + if !ok { + log.Tracer(r.Context()).Tracef("api: provided session cookie %s is unknown", c.Value) + return nil + } + + // Check if session is still valid. + if sess.Expired() { + log.Tracer(r.Context()).Tracef("api: provided session cookie %s has expired", c.Value) + return nil + } + + // Refresh session and return. + sess.Refresh(sessionCookieTTL) + log.Tracer(r.Context()).Tracef("api: session cookie %s is valid, refreshing", c.Value) + return sess.token +} + +func createSession(w http.ResponseWriter, r *http.Request, token *AuthToken) error { + // Generate new session key. 
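updateAPIKeys treats every configured key as a URL: the path part is the secret itself and the query carries the permissions plus an optional RFC3339 expiry. A small sketch of that format; the secret and the expiry date are made up.

```go
package main

import (
	"fmt"
	"net/url"
)

func main() {
	// Format: <secret>?read=<permission>&write=<permission>&expires=<RFC3339>
	key := "my-secret-key?read=user&write=admin&expires=2030-01-01T00:00:00Z"

	u, err := url.Parse(key)
	if err != nil {
		panic(err)
	}
	q := u.Query()
	fmt.Println("secret: ", u.Path)           // my-secret-key
	fmt.Println("read:   ", q.Get("read"))    // user
	fmt.Println("write:  ", q.Get("write"))   // admin
	fmt.Println("expires:", q.Get("expires")) // 2030-01-01T00:00:00Z
}
```

Keys whose expiry lies in the past are skipped and later removed from the setting by the low-priority cleanup task.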
+ secret, err := rng.Bytes(32) // 256 bit + if err != nil { + return err + } + sessionKey := base64.RawURLEncoding.EncodeToString(secret) + + // Set token cookie in response. + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: sessionKey, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }) + + // Create session. + sess := &session{ + token: token, + } + sess.Refresh(sessionCookieTTL) + + // Save session. + sessionsLock.Lock() + defer sessionsLock.Unlock() + sessions[sessionKey] = sess + log.Tracer(r.Context()).Debug("api: issued session cookie") + + return nil +} + +func cleanSessions(_ context.Context, _ *modules.Task) error { + sessionsLock.Lock() + defer sessionsLock.Unlock() + + for sessionKey, sess := range sessions { + if sess.Expired() { + delete(sessions, sessionKey) + } + } + + return nil +} + +func deleteSession(sessionKey string) { + sessionsLock.Lock() + defer sessionsLock.Unlock() + + delete(sessions, sessionKey) +} + +func getEffectiveMethod(r *http.Request) (eMethod string, readMethod bool, ok bool) { + method := r.Method + + // Get CORS request method if OPTIONS request. + if r.Method == http.MethodOptions { + method = r.Header.Get("Access-Control-Request-Method") + if method == "" { + return "", false, false + } + } + + switch method { + case http.MethodGet, http.MethodHead: + return http.MethodGet, true, true + case http.MethodPost, http.MethodPut, http.MethodDelete: + return method, false, true + default: + return "", false, false + } +} + +func parseAPIPermission(s string) (Permission, error) { + switch strings.ToLower(s) { + case "", "anyone": + return PermitAnyone, nil + case "user": + return PermitUser, nil + case "admin": + return PermitAdmin, nil + default: + return PermitAnyone, fmt.Errorf("invalid permission: %s", s) + } +} + +func (p Permission) String() string { + switch p { + case NotSupported: + return "NotSupported" + case Dynamic: + return "Dynamic" + case PermitAnyone: + return "PermitAnyone" + case PermitUser: + return "PermitUser" + case PermitAdmin: + return "PermitAdmin" + case PermitSelf: + return "PermitSelf" + case NotFound: + return "NotFound" + default: + return "Unknown" + } +} + +// Role returns a string representation of the permission role. 
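Once a request has been authenticated, createSession hands out a session cookie, so follow-up requests only need to carry that cookie. A client-side sketch reusing the illustrative address and key from above together with a standard cookie jar.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/cookiejar"
)

func main() {
	// The jar stores the session cookie issued by createSession.
	jar, err := cookiejar.New(nil)
	if err != nil {
		panic(err)
	}
	client := &http.Client{Jar: jar}

	// First request: authenticated with the (illustrative) API key.
	req, _ := http.NewRequest(http.MethodGet, "http://127.0.0.1:817/api/v1/ping", nil)
	req.Header.Set("Authorization", "Bearer my-secret-key")
	if resp, err := client.Do(req); err == nil {
		_ = resp.Body.Close()
	}

	// Follow-up request: authenticated by the session cookie alone, which
	// checkSessionCookie refreshes on every use until it expires.
	resp, err := client.Get("http://127.0.0.1:817/api/v1/ping")
	if err != nil {
		panic(err)
	}
	defer func() { _ = resp.Body.Close() }()
	fmt.Println(resp.Status)
}
```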
+func (p Permission) Role() string { + switch p { + case PermitAnyone: + return "Anyone" + case PermitUser: + return "User" + case PermitAdmin: + return "Admin" + case PermitSelf: + return "Self" + case Dynamic, NotFound, NotSupported: + return "Invalid" + default: + return "Invalid" + } +} diff --git a/base/api/authentication_test.go b/base/api/authentication_test.go new file mode 100644 index 000000000..3d7e7c504 --- /dev/null +++ b/base/api/authentication_test.go @@ -0,0 +1,194 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +var testToken = new(AuthToken) + +func testAuthenticator(r *http.Request, s *http.Server) (*AuthToken, error) { + switch { + case testToken.Read == -127 || testToken.Write == -127: + return nil, errors.New("test error") + case testToken.Read == -128 || testToken.Write == -128: + return nil, fmt.Errorf("%wdenied", ErrAPIAccessDeniedMessage) + default: + return testToken, nil + } +} + +type testAuthHandler struct { + Read Permission + Write Permission +} + +func (ah *testAuthHandler) ReadPermission(r *http.Request) Permission { + return ah.Read +} + +func (ah *testAuthHandler) WritePermission(r *http.Request) Permission { + return ah.Write +} + +func (ah *testAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Check if request is as expected. + ar := GetAPIRequest(r) + switch { + case ar == nil: + http.Error(w, "ar == nil", http.StatusInternalServerError) + case ar.AuthToken == nil: + http.Error(w, "ar.AuthToken == nil", http.StatusInternalServerError) + default: + http.Error(w, "auth success", http.StatusOK) + } +} + +func makeAuthTestPath(reading bool, p Permission) string { + if reading { + return fmt.Sprintf("/test/auth/read/%s", p) + } + return fmt.Sprintf("/test/auth/write/%s", p) +} + +func init() { + // Set test authenticator. + err := SetAuthenticator(testAuthenticator) + if err != nil { + panic(err) + } +} + +func TestPermissions(t *testing.T) { + t.Parallel() + + testHandler := &mainHandler{ + mux: mainMux, + } + + // Define permissions that need testing. + permissionsToTest := []Permission{ + NotSupported, + PermitAnyone, + PermitUser, + PermitAdmin, + PermitSelf, + Dynamic, + NotFound, + 100, // Test a too high value. + -100, // Test a too low value. + -127, // Simulate authenticator failure. + -128, // Simulate authentication denied message. + } + + // Register test handlers. + for _, p := range permissionsToTest { + RegisterHandler(makeAuthTestPath(true, p), &testAuthHandler{Read: p}) + RegisterHandler(makeAuthTestPath(false, p), &testAuthHandler{Write: p}) + } + + // Test all the combinations. + for _, requestPerm := range permissionsToTest { + for _, handlerPerm := range permissionsToTest { + for _, method := range []string{ + http.MethodGet, + http.MethodHead, + http.MethodPost, + http.MethodPut, + http.MethodDelete, + } { + + // Set request permission for test requests. + _, reading, _ := getEffectiveMethod(&http.Request{Method: method}) + if reading { + testToken.Read = requestPerm + testToken.Write = NotSupported + } else { + testToken.Read = NotSupported + testToken.Write = requestPerm + } + + // Evaluate expected result. + var expectSuccess bool + switch { + case handlerPerm == PermitAnyone: + // This is fast-tracked. There are not additional checks. + expectSuccess = true + case handlerPerm == Dynamic: + // This is turned into PermitAnyone in the authenticator. + // But authentication is still processed and the result still gets + // sanity checked! 
+ if requestPerm >= PermitAnyone && + requestPerm <= PermitSelf { + expectSuccess = true + } + // Another special case is when the handler requires permission to be + // processed but the authenticator fails to authenticate the request. + // In this case, a fallback token with PermitAnyone is used. + if requestPerm == -128 { + // -128 is used to simulate a permission denied message. + expectSuccess = true + } + case handlerPerm <= NotSupported: + // Invalid handler permission. + case handlerPerm > PermitSelf: + // Invalid handler permission. + case requestPerm <= NotSupported: + // Invalid request permission. + case requestPerm > PermitSelf: + // Invalid request permission. + case requestPerm < handlerPerm: + // Valid, but insufficient request permission. + default: + expectSuccess = true + } + + if expectSuccess { + // Test for success. + if !assert.HTTPBodyContains( + t, + testHandler.ServeHTTP, + method, + makeAuthTestPath(reading, handlerPerm), + nil, + "auth success", + ) { + t.Errorf( + "%s with %s (%d) to handler %s (%d)", + method, + requestPerm, requestPerm, + handlerPerm, handlerPerm, + ) + } + } else { + // Test for error. + if !assert.HTTPError(t, + testHandler.ServeHTTP, + method, + makeAuthTestPath(reading, handlerPerm), + nil, + ) { + t.Errorf( + "%s with %s (%d) to handler %s (%d)", + method, + requestPerm, requestPerm, + handlerPerm, handlerPerm, + ) + } + } + } + } + } +} + +func TestPermissionDefinitions(t *testing.T) { + t.Parallel() + + if NotSupported != 0 { + t.Fatalf("NotSupported must be zero, was %v", NotSupported) + } +} diff --git a/base/api/client/api.go b/base/api/client/api.go new file mode 100644 index 000000000..e1a166448 --- /dev/null +++ b/base/api/client/api.go @@ -0,0 +1,57 @@ +package client + +// Get sends a get command to the API. +func (c *Client) Get(key string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestGet, key, nil) + return op +} + +// Query sends a query command to the API. +func (c *Client) Query(query string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestQuery, query, nil) + return op +} + +// Sub sends a sub command to the API. +func (c *Client) Sub(query string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestSub, query, nil) + return op +} + +// Qsub sends a qsub command to the API. +func (c *Client) Qsub(query string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestQsub, query, nil) + return op +} + +// Create sends a create command to the API. +func (c *Client) Create(key string, value interface{}, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestCreate, key, value) + return op +} + +// Update sends an update command to the API. +func (c *Client) Update(key string, value interface{}, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestUpdate, key, value) + return op +} + +// Insert sends an insert command to the API. +func (c *Client) Insert(key string, value interface{}, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestInsert, key, value) + return op +} + +// Delete sends a delete command to the API. 
+func (c *Client) Delete(key string, handleFunc func(*Message)) *Operation { + op := c.NewOperation(handleFunc) + op.Send(msgRequestDelete, key, nil) + return op +} diff --git a/base/api/client/client.go b/base/api/client/client.go new file mode 100644 index 000000000..c2bd037ac --- /dev/null +++ b/base/api/client/client.go @@ -0,0 +1,240 @@ +package client + +import ( + "fmt" + "sync" + "time" + + "github.com/tevino/abool" + + "github.com/safing/portmaster/base/log" +) + +const ( + backOffTimer = 1 * time.Second + + offlineSignal uint8 = 0 + onlineSignal uint8 = 1 +) + +// The Client enables easy interaction with the API. +type Client struct { + sync.Mutex + + server string + + onlineSignal chan struct{} + offlineSignal chan struct{} + shutdownSignal chan struct{} + lastSignal uint8 + + send chan *Message + resend chan *Message + recv chan *Message + + operations map[string]*Operation + nextOpID uint64 + + lastError string +} + +// NewClient returns a new Client. +func NewClient(server string) *Client { + c := &Client{ + server: server, + onlineSignal: make(chan struct{}), + offlineSignal: make(chan struct{}), + shutdownSignal: make(chan struct{}), + lastSignal: offlineSignal, + send: make(chan *Message, 100), + resend: make(chan *Message, 1), + recv: make(chan *Message, 100), + operations: make(map[string]*Operation), + } + go c.handler() + return c +} + +// Connect connects to the API once. +func (c *Client) Connect() error { + defer c.signalOffline() + + err := c.wsConnect() + if err != nil && err.Error() != c.lastError { + log.Errorf("client: error connecting to Portmaster: %s", err) + c.lastError = err.Error() + } + return err +} + +// StayConnected calls Connect again whenever the connection is lost. +func (c *Client) StayConnected() { + log.Infof("client: connecting to Portmaster at %s", c.server) + + _ = c.Connect() + for { + select { + case <-time.After(backOffTimer): + log.Infof("client: reconnecting...") + _ = c.Connect() + case <-c.shutdownSignal: + return + } + } +} + +// Shutdown shuts the client down. +func (c *Client) Shutdown() { + select { + case <-c.shutdownSignal: + default: + close(c.shutdownSignal) + } +} + +func (c *Client) signalOnline() { + c.Lock() + defer c.Unlock() + if c.lastSignal == offlineSignal { + log.Infof("client: went online") + c.offlineSignal = make(chan struct{}) + close(c.onlineSignal) + c.lastSignal = onlineSignal + + // resend unsent request + for _, op := range c.operations { + if op.resuscitationEnabled.IsSet() && op.request.sent != nil && op.request.sent.SetToIf(true, false) { + op.client.send <- op.request + log.Infof("client: resuscitated %s %s %s", op.request.OpID, op.request.Type, op.request.Key) + } + } + + } +} + +func (c *Client) signalOffline() { + c.Lock() + defer c.Unlock() + if c.lastSignal == onlineSignal { + log.Infof("client: went offline") + c.onlineSignal = make(chan struct{}) + close(c.offlineSignal) + c.lastSignal = offlineSignal + + // signal offline status to operations + for _, op := range c.operations { + op.handle(&Message{ + OpID: op.ID, + Type: MsgOffline, + }) + } + + } +} + +// Online returns a closed channel read if the client is connected to the API. +func (c *Client) Online() <-chan struct{} { + c.Lock() + defer c.Unlock() + return c.onlineSignal +} + +// Offline returns a closed channel read if the client is not connected to the API. 
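Putting the command helpers and the connection handling together, a consumer of this client package might look roughly like the sketch below. The server address and the database query are placeholders, and the import path is assumed from the repository layout.

```go
package main

import (
	"log"
	"time"

	"github.com/safing/portmaster/base/api/client" // assumed import path
)

func main() {
	// Connect to the API and keep reconnecting in the background.
	c := client.NewClient("127.0.0.1:817")
	go c.StayConnected()
	defer c.Shutdown()

	// Query matching records and stay subscribed to future changes.
	op := c.Qsub("query config:", func(m *client.Message) {
		log.Printf("%s %s", m.Type, m.Key)
	})
	// Resend the request automatically after a reconnect.
	op.EnableResuscitation()

	// Block until the connection is up, then let the demo run briefly.
	<-c.Online()
	time.Sleep(10 * time.Second)
}
```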
+func (c *Client) Offline() <-chan struct{} { + c.Lock() + defer c.Unlock() + return c.offlineSignal +} + +func (c *Client) handler() { + for { + select { + + case m := <-c.recv: + + if m == nil { + return + } + + c.Lock() + op, ok := c.operations[m.OpID] + c.Unlock() + + if ok { + log.Tracef("client: [%s] received %s msg: %s", m.OpID, m.Type, m.Key) + op.handle(m) + } else { + log.Tracef("client: received message for unknown operation %s", m.OpID) + } + + case <-c.shutdownSignal: + return + + } + } +} + +// Operation represents a single operation by a client. +type Operation struct { + ID string + request *Message + client *Client + handleFunc func(*Message) + handler chan *Message + resuscitationEnabled *abool.AtomicBool +} + +func (op *Operation) handle(m *Message) { + if op.handleFunc != nil { + op.handleFunc(m) + } else { + select { + case op.handler <- m: + default: + log.Warningf("client: handler channel of operation %s overflowed", op.ID) + } + } +} + +// Cancel the operation. +func (op *Operation) Cancel() { + op.client.Lock() + defer op.client.Unlock() + delete(op.client.operations, op.ID) + close(op.handler) +} + +// Send sends a request to the API. +func (op *Operation) Send(command, text string, data interface{}) { + op.request = &Message{ + OpID: op.ID, + Type: command, + Key: text, + Value: data, + sent: abool.NewBool(false), + } + log.Tracef("client: [%s] sending %s msg: %s", op.request.OpID, op.request.Type, op.request.Key) + op.client.send <- op.request +} + +// EnableResuscitation will resend the request after reconnecting to the API. +func (op *Operation) EnableResuscitation() { + op.resuscitationEnabled.Set() +} + +// NewOperation returns a new operation. +func (c *Client) NewOperation(handleFunc func(*Message)) *Operation { + c.Lock() + defer c.Unlock() + + c.nextOpID++ + op := &Operation{ + ID: fmt.Sprintf("#%d", c.nextOpID), + client: c, + handleFunc: handleFunc, + handler: make(chan *Message, 100), + resuscitationEnabled: abool.NewBool(false), + } + c.operations[op.ID] = op + return op +} diff --git a/base/api/client/const.go b/base/api/client/const.go new file mode 100644 index 000000000..d882683d5 --- /dev/null +++ b/base/api/client/const.go @@ -0,0 +1,28 @@ +package client + +// Message Types. +const ( + msgRequestGet = "get" + msgRequestQuery = "query" + msgRequestSub = "sub" + msgRequestQsub = "qsub" + msgRequestCreate = "create" + msgRequestUpdate = "update" + msgRequestInsert = "insert" + msgRequestDelete = "delete" + + MsgOk = "ok" + MsgError = "error" + MsgDone = "done" + MsgSuccess = "success" + MsgUpdate = "upd" + MsgNew = "new" + MsgDelete = "del" + MsgWarning = "warning" + + MsgOffline = "offline" // special message type for signaling the handler that the connection was lost + + apiSeperator = "|" +) + +var apiSeperatorBytes = []byte(apiSeperator) diff --git a/base/api/client/message.go b/base/api/client/message.go new file mode 100644 index 000000000..0927eb035 --- /dev/null +++ b/base/api/client/message.go @@ -0,0 +1,95 @@ +package client + +import ( + "bytes" + "errors" + + "github.com/tevino/abool" + + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" +) + +// ErrMalformedMessage is returned when a malformed message was encountered. +var ErrMalformedMessage = errors.New("malformed message") + +// Message is an API message. 
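The wire format behind these operations is the pipe-separated framing listed in const.go above. A sketch of packing a request and parsing a reply with the Message helpers that follow; the operation ID, key, and payload are invented, and real replies may carry a serialization-format prefix before the JSON payload.

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/api/client" // assumed import path
)

func main() {
	// Pack a request: <op id>|<type>|<key>  ->  "#1|get|core:status"
	req := &client.Message{OpID: "#1", Type: "get", Key: "core:status"}
	raw, err := req.Pack()
	if err != nil {
		panic(err)
	}
	fmt.Println(string(raw))

	// Parse a reply: <op id>|ok|<key>|<payload>
	reply, err := client.ParseMessage([]byte(`#1|ok|core:status|{"Value":42}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(reply.OpID, reply.Type, reply.Key, string(reply.RawValue))
}
```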
+type Message struct { + OpID string + Type string + Key string + RawValue []byte + Value interface{} + sent *abool.AtomicBool +} + +// ParseMessage parses the given raw data and returns a Message. +func ParseMessage(data []byte) (*Message, error) { + parts := bytes.SplitN(data, apiSeperatorBytes, 4) + if len(parts) < 2 { + return nil, ErrMalformedMessage + } + + m := &Message{ + OpID: string(parts[0]), + Type: string(parts[1]), + } + + switch m.Type { + case MsgOk, MsgUpdate, MsgNew: + // parse key and data + // 127|ok|| + // 127|upd|| + // 127|new|| + if len(parts) != 4 { + return nil, ErrMalformedMessage + } + m.Key = string(parts[2]) + m.RawValue = parts[3] + case MsgDelete: + // parse key + // 127|del| + if len(parts) != 3 { + return nil, ErrMalformedMessage + } + m.Key = string(parts[2]) + case MsgWarning, MsgError: + // parse message + // 127|error| + // 127|warning| // error with single record, operation continues + if len(parts) != 3 { + return nil, ErrMalformedMessage + } + m.Key = string(parts[2]) + case MsgDone, MsgSuccess: + // nothing more to do + // 127|success + // 127|done + } + + return m, nil +} + +// Pack serializes a message into a []byte slice. +func (m *Message) Pack() ([]byte, error) { + c := container.New([]byte(m.OpID), apiSeperatorBytes, []byte(m.Type)) + + if m.Key != "" { + c.Append(apiSeperatorBytes) + c.Append([]byte(m.Key)) + if len(m.RawValue) > 0 { + c.Append(apiSeperatorBytes) + c.Append(m.RawValue) + } else if m.Value != nil { + var err error + m.RawValue, err = dsd.Dump(m.Value, dsd.JSON) + if err != nil { + return nil, err + } + c.Append(apiSeperatorBytes) + c.Append(m.RawValue) + } + } + + return c.CompileData(), nil +} diff --git a/base/api/client/websocket.go b/base/api/client/websocket.go new file mode 100644 index 000000000..a1f4b2339 --- /dev/null +++ b/base/api/client/websocket.go @@ -0,0 +1,121 @@ +package client + +import ( + "fmt" + "sync" + + "github.com/gorilla/websocket" + "github.com/tevino/abool" + + "github.com/safing/portmaster/base/log" +) + +type wsState struct { + wsConn *websocket.Conn + wg sync.WaitGroup + failing *abool.AtomicBool + failSignal chan struct{} +} + +func (c *Client) wsConnect() error { + state := &wsState{ + failing: abool.NewBool(false), + failSignal: make(chan struct{}), + } + + var err error + state.wsConn, _, err = websocket.DefaultDialer.Dial(fmt.Sprintf("ws://%s/api/database/v1", c.server), nil) + if err != nil { + return err + } + + c.signalOnline() + + state.wg.Add(2) + go c.wsReader(state) + go c.wsWriter(state) + + // wait for end of connection + select { + case <-state.failSignal: + case <-c.shutdownSignal: + state.Error("") + } + _ = state.wsConn.Close() + state.wg.Wait() + + return nil +} + +func (c *Client) wsReader(state *wsState) { + defer state.wg.Done() + for { + _, data, err := state.wsConn.ReadMessage() + log.Tracef("client: read message") + if err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + state.Error(fmt.Sprintf("client: read error: %s", err)) + } else { + state.Error("client: connection closed by server") + } + return + } + log.Tracef("client: received message: %s", string(data)) + m, err := ParseMessage(data) + if err != nil { + log.Warningf("client: failed to parse message: %s", err) + } else { + select { + case c.recv <- m: + case <-state.failSignal: + return + } + } + } +} + +func (c *Client) wsWriter(state *wsState) { + defer state.wg.Done() + for { + select { + case <-state.failSignal: + return + case m := <-c.resend: + data, 
err := m.Pack() + if err == nil { + err = state.wsConn.WriteMessage(websocket.BinaryMessage, data) + } + if err != nil { + state.Error(fmt.Sprintf("client: write error: %s", err)) + return + } + log.Tracef("client: sent message: %s", string(data)) + if m.sent != nil { + m.sent.Set() + } + case m := <-c.send: + data, err := m.Pack() + if err == nil { + err = state.wsConn.WriteMessage(websocket.BinaryMessage, data) + } + if err != nil { + c.resend <- m + state.Error(fmt.Sprintf("client: write error: %s", err)) + return + } + log.Tracef("client: sent message: %s", string(data)) + if m.sent != nil { + m.sent.Set() + } + } + } +} + +func (state *wsState) Error(message string) { + if state.failing.SetToIf(false, true) { + close(state.failSignal) + if message != "" { + log.Warning(message) + } + } +} diff --git a/base/api/config.go b/base/api/config.go new file mode 100644 index 000000000..4b55128f0 --- /dev/null +++ b/base/api/config.go @@ -0,0 +1,91 @@ +package api + +import ( + "flag" + + "github.com/safing/portmaster/base/config" +) + +// Config Keys. +const ( + CfgDefaultListenAddressKey = "core/listenAddress" + CfgAPIKeys = "core/apiKeys" +) + +var ( + listenAddressFlag string + listenAddressConfig config.StringOption + defaultListenAddress string + + configuredAPIKeys config.StringArrayOption + + devMode config.BoolOption +) + +func init() { + flag.StringVar( + &listenAddressFlag, + "api-address", + "", + "set api listen address; configuration is stronger", + ) +} + +func getDefaultListenAddress() string { + // check if overridden + if listenAddressFlag != "" { + return listenAddressFlag + } + // return internal default + return defaultListenAddress +} + +func registerConfig() error { + err := config.Register(&config.Option{ + Name: "API Listen Address", + Key: CfgDefaultListenAddressKey, + Description: "Defines the IP address and port on which the internal API listens.", + OptType: config.OptTypeString, + ExpertiseLevel: config.ExpertiseLevelDeveloper, + ReleaseLevel: config.ReleaseLevelStable, + DefaultValue: getDefaultListenAddress(), + ValidationRegex: "^([0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}:[0-9]{1,5}|\\[[:0-9A-Fa-f]+\\]:[0-9]{1,5})$", + RequiresRestart: true, + Annotations: config.Annotations{ + config.DisplayOrderAnnotation: 513, + config.CategoryAnnotation: "Development", + }, + }) + if err != nil { + return err + } + listenAddressConfig = config.GetAsString(CfgDefaultListenAddressKey, getDefaultListenAddress()) + + err = config.Register(&config.Option{ + Name: "API Keys", + Key: CfgAPIKeys, + Description: "Define API keys for privileged access to the API. Every entry is a separate API key with respective permissions. Format is `?read=&write=`. Permissions are `anyone`, `user` and `admin`, and may be omitted.", + Sensitive: true, + OptType: config.OptTypeStringArray, + ExpertiseLevel: config.ExpertiseLevelDeveloper, + ReleaseLevel: config.ReleaseLevelStable, + DefaultValue: []string{}, + Annotations: config.Annotations{ + config.DisplayOrderAnnotation: 514, + config.CategoryAnnotation: "Development", + }, + }) + if err != nil { + return err + } + configuredAPIKeys = config.GetAsStringArray(CfgAPIKeys, []string{}) + + devMode = config.Concurrent.GetAsBool(config.CfgDevModeKey, false) + + return nil +} + +// SetDefaultAPIListenAddress sets the default listen address for the API. 
+func SetDefaultAPIListenAddress(address string) { + defaultListenAddress = address +} diff --git a/base/api/database.go b/base/api/database.go new file mode 100644 index 000000000..e7bdb7073 --- /dev/null +++ b/base/api/database.go @@ -0,0 +1,698 @@ +package api + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http" + "sync" + + "github.com/gorilla/websocket" + "github.com/tevino/abool" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/log" +) + +const ( + dbMsgTypeOk = "ok" + dbMsgTypeError = "error" + dbMsgTypeDone = "done" + dbMsgTypeSuccess = "success" + dbMsgTypeUpd = "upd" + dbMsgTypeNew = "new" + dbMsgTypeDel = "del" + dbMsgTypeWarning = "warning" + + dbAPISeperator = "|" + emptyString = "" +) + +var ( + dbAPISeperatorBytes = []byte(dbAPISeperator) + dbCompatibilityPermission = PermitAdmin +) + +func init() { + RegisterHandler("/api/database/v1", WrapInAuthHandler( + startDatabaseWebsocketAPI, + // Default to admin read/write permissions until the database gets support + // for api permissions. + dbCompatibilityPermission, + dbCompatibilityPermission, + )) +} + +// DatabaseAPI is a generic database API interface. +type DatabaseAPI struct { + queriesLock sync.Mutex + queries map[string]*iterator.Iterator + + subsLock sync.Mutex + subs map[string]*database.Subscription + + shutdownSignal chan struct{} + shuttingDown *abool.AtomicBool + db *database.Interface + + sendBytes func(data []byte) +} + +// DatabaseWebsocketAPI is a database websocket API interface. +type DatabaseWebsocketAPI struct { + DatabaseAPI + + sendQueue chan []byte + conn *websocket.Conn +} + +func allowAnyOrigin(r *http.Request) bool { + return true +} + +// CreateDatabaseAPI creates a new database interface. 
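An embedding program would typically set its built-in default before the module starts, as sketched below with an example address. Per the code above, the -api-address flag overrides this built-in default, and the core/listenAddress setting overrides both.

```go
package main

import (
	"flag"

	"github.com/safing/portmaster/base/api" // assumed import path
)

func main() {
	// Built-in fallback; should be set before the api module registers its
	// config options, since it becomes the default of core/listenAddress.
	api.SetDefaultAPIListenAddress("127.0.0.1:817")

	// The api package registers -api-address in its init(); parsing flags
	// here picks it up together with the program's own flags.
	flag.Parse()

	// ... start the module system here ...
}
```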
+func CreateDatabaseAPI(sendFunction func(data []byte)) DatabaseAPI { + return DatabaseAPI{ + queries: make(map[string]*iterator.Iterator), + subs: make(map[string]*database.Subscription), + shutdownSignal: make(chan struct{}), + shuttingDown: abool.NewBool(false), + db: database.NewInterface(nil), + sendBytes: sendFunction, + } +} + +func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: allowAnyOrigin, + ReadBufferSize: 1024, + WriteBufferSize: 65536, + } + wsConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + errMsg := fmt.Sprintf("could not upgrade: %s", err) + log.Error(errMsg) + http.Error(w, errMsg, http.StatusBadRequest) + return + } + + newDBAPI := &DatabaseWebsocketAPI{ + DatabaseAPI: DatabaseAPI{ + queries: make(map[string]*iterator.Iterator), + subs: make(map[string]*database.Subscription), + shutdownSignal: make(chan struct{}), + shuttingDown: abool.NewBool(false), + db: database.NewInterface(nil), + }, + + sendQueue: make(chan []byte, 100), + conn: wsConn, + } + + newDBAPI.sendBytes = func(data []byte) { + newDBAPI.sendQueue <- data + } + + module.StartWorker("database api handler", newDBAPI.handler) + module.StartWorker("database api writer", newDBAPI.writer) + + log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI) +} + +func (api *DatabaseWebsocketAPI) handler(context.Context) error { + defer func() { + _ = api.shutdown(nil) + }() + + for { + _, msg, err := api.conn.ReadMessage() + if err != nil { + return api.shutdown(err) + } + + api.Handle(msg) + } +} + +func (api *DatabaseWebsocketAPI) writer(ctx context.Context) error { + defer func() { + _ = api.shutdown(nil) + }() + + var data []byte + var err error + + for { + select { + // prioritize direct writes + case data = <-api.sendQueue: + if len(data) == 0 { + return nil + } + case <-ctx.Done(): + return nil + case <-api.shutdownSignal: + return nil + } + + // log.Tracef("api: sending %s", string(*msg)) + err = api.conn.WriteMessage(websocket.BinaryMessage, data) + if err != nil { + return api.shutdown(err) + } + } +} + +func (api *DatabaseWebsocketAPI) shutdown(err error) error { + // Check if we are the first to shut down. + if !api.shuttingDown.SetToIf(false, true) { + return nil + } + + // Check the given error. + if err != nil { + if websocket.IsCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) { + log.Infof("api: websocket connection to %s closed", api.conn.RemoteAddr()) + } else { + log.Warningf("api: websocket connection error with %s: %s", api.conn.RemoteAddr(), err) + } + } + + // Trigger shutdown. + close(api.shutdownSignal) + _ = api.conn.Close() + return nil +} + +// Handle handles a message for the database API. 
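Besides the websocket variant set up above, CreateDatabaseAPI allows bridging the same protocol over any transport: supply a send function for replies and feed raw requests into Handle (defined next). A structural sketch only; in a real program the database module must be running, and replies arrive asynchronously.

```go
package main

import (
	"fmt"
	"time"

	"github.com/safing/portmaster/base/api" // assumed import path
)

func main() {
	// Replies are handed to the send function as raw, pipe-separated frames.
	dbAPI := api.CreateDatabaseAPI(func(data []byte) {
		fmt.Printf("reply: %s\n", data)
	})

	// Feed in a raw "get" request; the key is illustrative. Without a running
	// database this should simply yield an error reply.
	dbAPI.Handle([]byte("1|get|core:status"))

	// Crude wait for the asynchronous reply; a real transport would run its
	// own read and write loops instead.
	time.Sleep(time.Second)
}
```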
+func (api *DatabaseAPI) Handle(msg []byte) { + // 123|get| + // 123|ok|| + // 123|error| + // 124|query| + // 124|ok|| + // 124|done + // 124|error| + // 124|warning| // error with single record, operation continues + // 124|cancel + // 125|sub| + // 125|upd|| + // 125|new|| + // 127|del| + // 125|warning| // error with single record, operation continues + // 125|cancel + // 127|qsub| + // 127|ok|| + // 127|done + // 127|error| + // 127|upd|| + // 127|new|| + // 127|del| + // 127|warning| // error with single record, operation continues + // 127|cancel + + // 128|create|| + // 128|success + // 128|error| + // 129|update|| + // 129|success + // 129|error| + // 130|insert|| + // 130|success + // 130|error| + // 131|delete| + // 131|success + // 131|error| + + parts := bytes.SplitN(msg, []byte("|"), 3) + + // Handle special command "cancel" + if len(parts) == 2 && string(parts[1]) == "cancel" { + // 124|cancel + // 125|cancel + // 127|cancel + go api.handleCancel(parts[0]) + return + } + + if len(parts) != 3 { + api.send(nil, dbMsgTypeError, "bad request: malformed message", nil) + return + } + + switch string(parts[1]) { + case "get": + // 123|get| + go api.handleGet(parts[0], string(parts[2])) + case "query": + // 124|query| + go api.handleQuery(parts[0], string(parts[2])) + case "sub": + // 125|sub| + go api.handleSub(parts[0], string(parts[2])) + case "qsub": + // 127|qsub| + go api.handleQsub(parts[0], string(parts[2])) + case "create", "update", "insert": + // split key and payload + dataParts := bytes.SplitN(parts[2], []byte("|"), 2) + if len(dataParts) != 2 { + api.send(nil, dbMsgTypeError, "bad request: malformed message", nil) + return + } + + switch string(parts[1]) { + case "create": + // 128|create|| + go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], true) + case "update": + // 129|update|| + go api.handlePut(parts[0], string(dataParts[0]), dataParts[1], false) + case "insert": + // 130|insert|| + go api.handleInsert(parts[0], string(dataParts[0]), dataParts[1]) + } + case "delete": + // 131|delete| + go api.handleDelete(parts[0], string(parts[2])) + default: + api.send(parts[0], dbMsgTypeError, "bad request: unknown method", nil) + } +} + +func (api *DatabaseAPI) send(opID []byte, msgType string, msgOrKey string, data []byte) { + c := container.New(opID) + c.Append(dbAPISeperatorBytes) + c.Append([]byte(msgType)) + + if msgOrKey != emptyString { + c.Append(dbAPISeperatorBytes) + c.Append([]byte(msgOrKey)) + } + + if len(data) > 0 { + c.Append(dbAPISeperatorBytes) + c.Append(data) + } + + api.sendBytes(c.CompileData()) +} + +func (api *DatabaseAPI) handleGet(opID []byte, key string) { + // 123|get| + // 123|ok|| + // 123|error| + + var data []byte + + r, err := api.db.Get(key) + if err == nil { + data, err = MarshalRecord(r, true) + } + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + api.send(opID, dbMsgTypeOk, r.Key(), data) +} + +func (api *DatabaseAPI) handleQuery(opID []byte, queryText string) { + // 124|query| + // 124|ok|| + // 124|done + // 124|warning| + // 124|error| + // 124|warning| // error with single record, operation continues + // 124|cancel + + var err error + + q, err := query.ParseQuery(queryText) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + + api.processQuery(opID, q) +} + +func (api *DatabaseAPI) processQuery(opID []byte, q *query.Query) (ok bool) { + it, err := api.db.Query(q) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return false + } + + // 
Save query iterator. + api.queriesLock.Lock() + api.queries[string(opID)] = it + api.queriesLock.Unlock() + + // Remove query iterator after it ended. + defer func() { + api.queriesLock.Lock() + defer api.queriesLock.Unlock() + delete(api.queries, string(opID)) + }() + + for { + select { + case <-api.shutdownSignal: + // cancel query and return + it.Cancel() + return false + case r := <-it.Next: + // process query feed + if r != nil { + // process record + data, err := MarshalRecord(r, true) + if err != nil { + api.send(opID, dbMsgTypeWarning, err.Error(), nil) + continue + } + api.send(opID, dbMsgTypeOk, r.Key(), data) + } else { + // sub feed ended + if it.Err() != nil { + api.send(opID, dbMsgTypeError, it.Err().Error(), nil) + return false + } + api.send(opID, dbMsgTypeDone, emptyString, nil) + return true + } + } + } +} + +// func (api *DatabaseWebsocketAPI) runQuery() + +func (api *DatabaseAPI) handleSub(opID []byte, queryText string) { + // 125|sub| + // 125|upd|| + // 125|new|| + // 125|delete| + // 125|warning| // error with single record, operation continues + // 125|cancel + var err error + + q, err := query.ParseQuery(queryText) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + + sub, ok := api.registerSub(opID, q) + if !ok { + return + } + api.processSub(opID, sub) +} + +func (api *DatabaseAPI) registerSub(opID []byte, q *query.Query) (sub *database.Subscription, ok bool) { + var err error + sub, err = api.db.Subscribe(q) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return nil, false + } + + return sub, true +} + +func (api *DatabaseAPI) processSub(opID []byte, sub *database.Subscription) { + // Save subscription. + api.subsLock.Lock() + api.subs[string(opID)] = sub + api.subsLock.Unlock() + + // Remove subscription after it ended. + defer func() { + api.subsLock.Lock() + defer api.subsLock.Unlock() + delete(api.subs, string(opID)) + }() + + for { + select { + case <-api.shutdownSignal: + // cancel sub and return + _ = sub.Cancel() + return + case r := <-sub.Feed: + // process sub feed + if r != nil { + // process record + data, err := MarshalRecord(r, true) + if err != nil { + api.send(opID, dbMsgTypeWarning, err.Error(), nil) + continue + } + // TODO: use upd, new and delete msgTypes + r.Lock() + isDeleted := r.Meta().IsDeleted() + isNew := r.Meta().Created == r.Meta().Modified + r.Unlock() + switch { + case isDeleted: + api.send(opID, dbMsgTypeDel, r.Key(), nil) + case isNew: + api.send(opID, dbMsgTypeNew, r.Key(), data) + default: + api.send(opID, dbMsgTypeUpd, r.Key(), data) + } + } else { + // sub feed ended + api.send(opID, dbMsgTypeDone, "", nil) + return + } + } + } +} + +func (api *DatabaseAPI) handleQsub(opID []byte, queryText string) { + // 127|qsub| + // 127|ok|| + // 127|done + // 127|error| + // 127|upd|| + // 127|new|| + // 127|delete| + // 127|warning| // error with single record, operation continues + // 127|cancel + + var err error + + q, err := query.ParseQuery(queryText) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + + sub, ok := api.registerSub(opID, q) + if !ok { + return + } + ok = api.processQuery(opID, q) + if !ok { + return + } + api.processSub(opID, sub) +} + +func (api *DatabaseAPI) handleCancel(opID []byte) { + api.cancelQuery(opID) + api.cancelSub(opID) +} + +func (api *DatabaseAPI) cancelQuery(opID []byte) { + api.queriesLock.Lock() + defer api.queriesLock.Unlock() + + // Get subscription from api. 
+ it, ok := api.queries[string(opID)] + if !ok { + // Fail silently as quries end by themselves when finished. + return + } + + // End query. + it.Cancel() + + // The query handler will end the communication with a done message. +} + +func (api *DatabaseAPI) cancelSub(opID []byte) { + api.subsLock.Lock() + defer api.subsLock.Unlock() + + // Get subscription from api. + sub, ok := api.subs[string(opID)] + if !ok { + api.send(opID, dbMsgTypeError, "could not find subscription", nil) + return + } + + // End subscription. + err := sub.Cancel() + if err != nil { + api.send(opID, dbMsgTypeError, fmt.Sprintf("failed to cancel subscription: %s", err), nil) + } + + // The subscription handler will end the communication with a done message. +} + +func (api *DatabaseAPI) handlePut(opID []byte, key string, data []byte, create bool) { + // 128|create|| + // 128|success + // 128|error| + + // 129|update|| + // 129|success + // 129|error| + + if len(data) < 2 { + api.send(opID, dbMsgTypeError, "bad request: malformed message", nil) + return + } + + // TODO - staged for deletion: remove transition code + // if data[0] != dsd.JSON { + // typedData := make([]byte, len(data)+1) + // typedData[0] = dsd.JSON + // copy(typedData[1:], data) + // data = typedData + // } + + r, err := record.NewWrapper(key, nil, data[0], data[1:]) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + + if create { + err = api.db.PutNew(r) + } else { + err = api.db.Put(r) + } + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + api.send(opID, dbMsgTypeSuccess, emptyString, nil) +} + +func (api *DatabaseAPI) handleInsert(opID []byte, key string, data []byte) { + // 130|insert|| + // 130|success + // 130|error| + + r, err := api.db.Get(key) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + + acc := r.GetAccessor(r) + + result := gjson.ParseBytes(data) + anythingPresent := false + var insertError error + result.ForEach(func(key gjson.Result, value gjson.Result) bool { + anythingPresent = true + if !key.Exists() { + insertError = errors.New("values must be in a map") + return false + } + if key.Type != gjson.String { + insertError = errors.New("keys must be strings") + return false + } + if !value.Exists() { + insertError = errors.New("non-existent value") + return false + } + insertError = acc.Set(key.String(), value.Value()) + return insertError == nil + }) + + if insertError != nil { + api.send(opID, dbMsgTypeError, insertError.Error(), nil) + return + } + if !anythingPresent { + api.send(opID, dbMsgTypeError, "could not find any valid values", nil) + return + } + + err = api.db.Put(r) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + + api.send(opID, dbMsgTypeSuccess, emptyString, nil) +} + +func (api *DatabaseAPI) handleDelete(opID []byte, key string) { + // 131|delete| + // 131|success + // 131|error| + + err := api.db.Delete(key) + if err != nil { + api.send(opID, dbMsgTypeError, err.Error(), nil) + return + } + api.send(opID, dbMsgTypeSuccess, emptyString, nil) +} + +// MarshalRecord locks and marshals the given record, additionally adding +// metadata and returning it as json. +func MarshalRecord(r record.Record, withDSDIdentifier bool) ([]byte, error) { + r.Lock() + defer r.Unlock() + + // Pour record into JSON. + jsonData, err := r.Marshal(r, dsd.JSON) + if err != nil { + return nil, err + } + + // Remove JSON identifier for manual editing. 
+ jsonData = bytes.TrimPrefix(jsonData, varint.Pack8(dsd.JSON)) + + // Add metadata. + jsonData, err = sjson.SetBytes(jsonData, "_meta", r.Meta()) + if err != nil { + return nil, err + } + + // Add database key. + jsonData, err = sjson.SetBytes(jsonData, "_meta.Key", r.Key()) + if err != nil { + return nil, err + } + + // Add JSON identifier again. + if withDSDIdentifier { + formatID := varint.Pack8(dsd.JSON) + finalData := make([]byte, 0, len(formatID)+len(jsonData)) + finalData = append(finalData, formatID...) + finalData = append(finalData, jsonData...) + return finalData, nil + } + return jsonData, nil +} diff --git a/base/api/doc.go b/base/api/doc.go new file mode 100644 index 000000000..d9a91bae0 --- /dev/null +++ b/base/api/doc.go @@ -0,0 +1,10 @@ +/* +Package api provides an API for integration with other components of the same software package and also third party components. + +It provides direct database access as well as a simpler way to register API endpoints. You can of course also register raw `http.Handler`s directly. + +Optional authentication guards registered handlers. This is achieved by attaching functions to the `http.Handler`s that are registered, which allow them to specify the required permissions for the handler. + +The permissions are divided into the roles and assume a single user per host. The Roles are User, Admin and Self. User roles are expected to have mostly read access and react to notifications or system events, like a system tray program. The Admin role is meant for advanced components that also change settings, but are restricted so they cannot break the software. Self is reserved for internal use with full access. +*/ +package api diff --git a/base/api/endpoints.go b/base/api/endpoints.go new file mode 100644 index 000000000..8f8769b88 --- /dev/null +++ b/base/api/endpoints.go @@ -0,0 +1,532 @@ +package api + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "sort" + "strconv" + "strings" + "sync" + + "github.com/gorilla/mux" + + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" +) + +// Endpoint describes an API Endpoint. +// Path and at least one permission are required. +// As is exactly one function. +type Endpoint struct { //nolint:maligned + // Name is the human reabable name of the endpoint. + Name string + // Description is the human readable description and documentation of the endpoint. + Description string + // Parameters is the parameter documentation. + Parameters []Parameter `json:",omitempty"` + + // Path describes the URL path of the endpoint. + Path string + + // MimeType defines the content type of the returned data. + MimeType string + + // Read defines the required read permission. + Read Permission `json:",omitempty"` + + // ReadMethod sets the required read method for the endpoint. + // Available methods are: + // GET: Returns data only, no action is taken, nothing is changed. + // If omitted, defaults to GET. + // + // This field is currently being introduced and will only warn and not deny + // access if the write method does not match. + ReadMethod string `json:",omitempty"` + + // Write defines the required write permission. + Write Permission `json:",omitempty"` + + // WriteMethod sets the required write method for the endpoint. 
+ // Available methods are: + // POST: Create a new resource; Change a status; Execute a function + // PUT: Update an existing resource + // DELETE: Remove an existing resource + // If omitted, defaults to POST. + // + // This field is currently being introduced and will only warn and not deny + // access if the write method does not match. + WriteMethod string `json:",omitempty"` + + // BelongsTo defines which module this endpoint belongs to. + // The endpoint will not be accessible if the module is not online. + BelongsTo *modules.Module `json:"-"` + + // ActionFunc is for simple actions with a return message for the user. + ActionFunc ActionFunc `json:"-"` + + // DataFunc is for returning raw data that the caller for further processing. + DataFunc DataFunc `json:"-"` + + // StructFunc is for returning any kind of struct. + StructFunc StructFunc `json:"-"` + + // RecordFunc is for returning a database record. It will be properly locked + // and marshalled including metadata. + RecordFunc RecordFunc `json:"-"` + + // HandlerFunc is the raw http handler. + HandlerFunc http.HandlerFunc `json:"-"` +} + +// Parameter describes a parameterized variation of an endpoint. +type Parameter struct { + Method string + Field string + Value string + Description string +} + +// HTTPStatusProvider is an interface for errors to provide a custom HTTP +// status code. +type HTTPStatusProvider interface { + HTTPStatus() int +} + +// HTTPStatusError represents an error with an HTTP status code. +type HTTPStatusError struct { + err error + code int +} + +// Error returns the error message. +func (e *HTTPStatusError) Error() string { + return e.err.Error() +} + +// Unwrap return the wrapped error. +func (e *HTTPStatusError) Unwrap() error { + return e.err +} + +// HTTPStatus returns the HTTP status code this error. +func (e *HTTPStatusError) HTTPStatus() int { + return e.code +} + +// ErrorWithStatus adds the HTTP status code to the error. +func ErrorWithStatus(err error, code int) error { + return &HTTPStatusError{ + err: err, + code: code, + } +} + +type ( + // ActionFunc is for simple actions with a return message for the user. + ActionFunc func(ar *Request) (msg string, err error) + + // DataFunc is for returning raw data that the caller for further processing. + DataFunc func(ar *Request) (data []byte, err error) + + // StructFunc is for returning any kind of struct. + StructFunc func(ar *Request) (i interface{}, err error) + + // RecordFunc is for returning a database record. It will be properly locked + // and marshalled including metadata. + RecordFunc func(ar *Request) (r record.Record, err error) +) + +// MIME Types. +const ( + MimeTypeJSON string = "application/json" + MimeTypeText string = "text/plain" + + apiV1Path = "/api/v1/" +) + +func init() { + RegisterHandler(apiV1Path+"{endpointPath:.+}", &endpointHandler{}) +} + +var ( + endpoints = make(map[string]*Endpoint) + endpointsMux = mux.NewRouter() + endpointsLock sync.RWMutex + + // ErrInvalidEndpoint is returned when an invalid endpoint is registered. + ErrInvalidEndpoint = errors.New("endpoint is invalid") + + // ErrAlreadyRegistered is returned when there already is an endpoint with + // the same path registered. + ErrAlreadyRegistered = errors.New("an endpoint for this path is already registered") +) + +func getAPIContext(r *http.Request) (apiEndpoint *Endpoint, apiRequest *Request) { + // Get request context and check if we already have an action cached. 
+ apiRequest = GetAPIRequest(r) + if apiRequest == nil { + return nil, nil + } + var ok bool + apiEndpoint, ok = apiRequest.HandlerCache.(*Endpoint) + if ok { + return apiEndpoint, apiRequest + } + + endpointsLock.RLock() + defer endpointsLock.RUnlock() + + // Get handler for request. + // Gorilla does not support handling this on our own very well. + // See github.com/gorilla/mux.ServeHTTP for reference. + var match mux.RouteMatch + var handler http.Handler + if endpointsMux.Match(r, &match) { + handler = match.Handler + apiRequest.Route = match.Route + // Add/Override variables instead of replacing. + for k, v := range match.Vars { + apiRequest.URLVars[k] = v + } + } else { + return nil, apiRequest + } + + apiEndpoint, ok = handler.(*Endpoint) + if ok { + // Cache for next operation. + apiRequest.HandlerCache = apiEndpoint + } + return apiEndpoint, apiRequest +} + +// RegisterEndpoint registers a new endpoint. An error will be returned if it +// does not pass the sanity checks. +func RegisterEndpoint(e Endpoint) error { + if err := e.check(); err != nil { + return fmt.Errorf("%w: %w", ErrInvalidEndpoint, err) + } + + endpointsLock.Lock() + defer endpointsLock.Unlock() + + _, ok := endpoints[e.Path] + if ok { + return ErrAlreadyRegistered + } + + endpoints[e.Path] = &e + endpointsMux.Handle(apiV1Path+e.Path, &e) + return nil +} + +// GetEndpointByPath returns the endpoint registered with the given path. +func GetEndpointByPath(path string) (*Endpoint, error) { + endpointsLock.Lock() + defer endpointsLock.Unlock() + endpoint, ok := endpoints[path] + if !ok { + return nil, fmt.Errorf("no registered endpoint on path: %q", path) + } + + return endpoint, nil +} + +func (e *Endpoint) check() error { + // Check path. + if strings.TrimSpace(e.Path) == "" { + return errors.New("path is missing") + } + + // Check permissions. + if e.Read < Dynamic || e.Read > PermitSelf { + return errors.New("invalid read permission") + } + if e.Write < Dynamic || e.Write > PermitSelf { + return errors.New("invalid write permission") + } + + // Check methods. + if e.Read != NotSupported { + switch e.ReadMethod { + case http.MethodGet: + // All good. + case "": + // Set to default. + e.ReadMethod = http.MethodGet + default: + return errors.New("invalid read method") + } + } else { + e.ReadMethod = "" + } + if e.Write != NotSupported { + switch e.WriteMethod { + case http.MethodPost, + http.MethodPut, + http.MethodDelete: + // All good. + case "": + // Set to default. + e.WriteMethod = http.MethodPost + default: + return errors.New("invalid write method") + } + } else { + e.WriteMethod = "" + } + + // Check functions. + var defaultMimeType string + fnCnt := 0 + if e.ActionFunc != nil { + fnCnt++ + defaultMimeType = MimeTypeText + } + if e.DataFunc != nil { + fnCnt++ + defaultMimeType = MimeTypeText + } + if e.StructFunc != nil { + fnCnt++ + defaultMimeType = MimeTypeJSON + } + if e.RecordFunc != nil { + fnCnt++ + defaultMimeType = MimeTypeJSON + } + if e.HandlerFunc != nil { + fnCnt++ + defaultMimeType = MimeTypeText + } + if fnCnt != 1 { + return errors.New("only one function may be set") + } + + // Set default mime type. + if e.MimeType == "" { + e.MimeType = defaultMimeType + } + + return nil +} + +// ExportEndpoints exports the registered endpoints. The returned data must be +// treated as immutable. +func ExportEndpoints() []*Endpoint { + endpointsLock.RLock() + defer endpointsLock.RUnlock() + + // Copy the map into a slice. 
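To make the registration rules concrete, a sketch of registering a simple action endpoint. The path, name, and behavior are invented; exactly one handler function may be set, and WriteMethod defaults to POST when omitted.

```go
package main

import (
	"log"

	"github.com/safing/portmaster/base/api" // assumed import path
)

func main() {
	err := api.RegisterEndpoint(api.Endpoint{
		Path:        "example/clear-cache",
		Write:       api.PermitAdmin,
		Name:        "Clear Example Cache",
		Description: "Clears a purely hypothetical cache.",
		// ActionFunc returns a short message for the user; DataFunc,
		// StructFunc or RecordFunc cover other response types.
		ActionFunc: func(ar *api.Request) (string, error) {
			// Real work would happen here.
			return "Cache cleared.", nil
		},
	})
	if err != nil {
		log.Fatal(err)
	}
}
```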
+ eps := make([]*Endpoint, 0, len(endpoints)) + for _, ep := range endpoints { + eps = append(eps, ep) + } + + sort.Sort(sortByPath(eps)) + return eps +} + +type sortByPath []*Endpoint + +func (eps sortByPath) Len() int { return len(eps) } +func (eps sortByPath) Less(i, j int) bool { return eps[i].Path < eps[j].Path } +func (eps sortByPath) Swap(i, j int) { eps[i], eps[j] = eps[j], eps[i] } + +type endpointHandler struct{} + +var _ AuthenticatedHandler = &endpointHandler{} // Compile time interface check. + +// ReadPermission returns the read permission for the handler. +func (eh *endpointHandler) ReadPermission(r *http.Request) Permission { + apiEndpoint, _ := getAPIContext(r) + if apiEndpoint != nil { + return apiEndpoint.Read + } + return NotFound +} + +// WritePermission returns the write permission for the handler. +func (eh *endpointHandler) WritePermission(r *http.Request) Permission { + apiEndpoint, _ := getAPIContext(r) + if apiEndpoint != nil { + return apiEndpoint.Write + } + return NotFound +} + +// ServeHTTP handles the http request. +func (eh *endpointHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + apiEndpoint, apiRequest := getAPIContext(r) + if apiEndpoint == nil || apiRequest == nil { + http.NotFound(w, r) + return + } + + apiEndpoint.ServeHTTP(w, r) +} + +// ServeHTTP handles the http request. +func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { + _, apiRequest := getAPIContext(r) + if apiRequest == nil { + http.NotFound(w, r) + return + } + + // Wait for the owning module to be ready. + if !moduleIsReady(e.BelongsTo) { + http.Error(w, "The API endpoint is not ready yet or the its module is not enabled. Reload (F5) to try again.", http.StatusServiceUnavailable) + return + } + + // Return OPTIONS request before starting to handle normal requests. + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + eMethod, readMethod, ok := getEffectiveMethod(r) + if !ok { + http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed) + return + } + + if readMethod { + if eMethod != e.ReadMethod { + log.Tracer(r.Context()).Warningf( + "api: method %q does not match required read method %q%s", + r.Method, + e.ReadMethod, + " - this will be an error and abort the request in the future", + ) + } + } else { + if eMethod != e.WriteMethod { + log.Tracer(r.Context()).Warningf( + "api: method %q does not match required write method %q%s", + r.Method, + e.WriteMethod, + " - this will be an error and abort the request in the future", + ) + } + } + + switch eMethod { + case http.MethodGet, http.MethodDelete: + // Nothing to do for these. + case http.MethodPost, http.MethodPut: + // Read body data. + inputData, ok := readBody(w, r) + if !ok { + return + } + apiRequest.InputData = inputData + + // restore request body for any http.HandlerFunc below + r.Body = io.NopCloser(bytes.NewReader(inputData)) + default: + // Defensive. + http.Error(w, "unsupported method for the actions API", http.StatusMethodNotAllowed) + return + } + + // Add response headers to request struct so that the endpoint can work with them. 
+ apiRequest.ResponseHeader = w.Header() + + // Execute action function and get response data + var responseData []byte + var err error + + switch { + case e.ActionFunc != nil: + var msg string + msg, err = e.ActionFunc(apiRequest) + if !strings.HasSuffix(msg, "\n") { + msg += "\n" + } + if err == nil { + responseData = []byte(msg) + } + + case e.DataFunc != nil: + responseData, err = e.DataFunc(apiRequest) + + case e.StructFunc != nil: + var v interface{} + v, err = e.StructFunc(apiRequest) + if err == nil && v != nil { + var mimeType string + responseData, mimeType, _, err = dsd.MimeDump(v, r.Header.Get("Accept")) + if err == nil { + w.Header().Set("Content-Type", mimeType) + } + } + + case e.RecordFunc != nil: + var rec record.Record + rec, err = e.RecordFunc(apiRequest) + if err == nil && r != nil { + responseData, err = MarshalRecord(rec, false) + } + + case e.HandlerFunc != nil: + e.HandlerFunc(w, r) + return + + default: + http.Error(w, "missing handler", http.StatusInternalServerError) + return + } + + // Check for handler error. + if err != nil { + var statusProvider HTTPStatusProvider + if errors.As(err, &statusProvider) { + http.Error(w, err.Error(), statusProvider.HTTPStatus()) + } else { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + + // Return no content if there is none, or if request is HEAD. + if len(responseData) == 0 || r.Method == http.MethodHead { + w.WriteHeader(http.StatusNoContent) + return + } + + // Set content type if not yet set. + if w.Header().Get("Content-Type") == "" { + w.Header().Set("Content-Type", e.MimeType+"; charset=utf-8") + } + + // Write response. + w.Header().Set("Content-Length", strconv.Itoa(len(responseData))) + w.WriteHeader(http.StatusOK) + _, err = w.Write(responseData) + if err != nil { + log.Tracer(r.Context()).Warningf("api: failed to write response: %s", err) + } +} + +func readBody(w http.ResponseWriter, r *http.Request) (inputData []byte, ok bool) { + // Check for too long content in order to prevent death. + if r.ContentLength > 20000000 { // 20MB + http.Error(w, "too much input data", http.StatusRequestEntityTooLarge) + return nil, false + } + + // Read and close body. + inputData, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read body"+err.Error(), http.StatusInternalServerError) + return nil, false + } + return inputData, true +} diff --git a/base/api/endpoints_config.go b/base/api/endpoints_config.go new file mode 100644 index 000000000..42ce4567d --- /dev/null +++ b/base/api/endpoints_config.go @@ -0,0 +1,24 @@ +package api + +import ( + "github.com/safing/portmaster/base/config" +) + +func registerConfigEndpoints() error { + if err := RegisterEndpoint(Endpoint{ + Path: "config/options", + Read: PermitAnyone, + MimeType: MimeTypeJSON, + StructFunc: listConfig, + Name: "Export Configuration Options", + Description: "Returns a list of all registered configuration options and their metadata. 
This does not include the current active or default settings.", + }); err != nil { + return err + } + + return nil +} + +func listConfig(ar *Request) (i interface{}, err error) { + return config.ExportOptions(), nil +} diff --git a/base/api/endpoints_debug.go b/base/api/endpoints_debug.go new file mode 100644 index 000000000..dc1ba6ff5 --- /dev/null +++ b/base/api/endpoints_debug.go @@ -0,0 +1,256 @@ +package api + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http" + "os" + "runtime/pprof" + "strings" + "time" + + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/utils/debug" +) + +func registerDebugEndpoints() error { + if err := RegisterEndpoint(Endpoint{ + Path: "ping", + Read: PermitAnyone, + ActionFunc: ping, + Name: "Ping", + Description: "Pong.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "ready", + Read: PermitAnyone, + ActionFunc: ready, + Name: "Ready", + Description: "Check if Portmaster has completed starting and is ready.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "debug/stack", + Read: PermitAnyone, + DataFunc: getStack, + Name: "Get Goroutine Stack", + Description: "Returns the current goroutine stack.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "debug/stack/print", + Read: PermitAnyone, + ActionFunc: printStack, + Name: "Print Goroutine Stack", + Description: "Prints the current goroutine stack to stdout.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "debug/cpu", + MimeType: "application/octet-stream", + Read: PermitAnyone, + DataFunc: handleCPUProfile, + Name: "Get CPU Profile", + Description: strings.ReplaceAll(`Gather and return the CPU profile. +This data needs to gathered over a period of time, which is specified using the duration parameter. + +You can easily view this data in your browser with this command (with Go installed): +"go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/cpu" +`, `"`, "`"), + Parameters: []Parameter{{ + Method: http.MethodGet, + Field: "duration", + Value: "10s", + Description: "Specify the formatting style. The default is simple markdown formatting.", + }}, + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "debug/heap", + MimeType: "application/octet-stream", + Read: PermitAnyone, + DataFunc: handleHeapProfile, + Name: "Get Heap Profile", + Description: strings.ReplaceAll(`Gather and return the heap memory profile. + + You can easily view this data in your browser with this command (with Go installed): + "go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/heap" + `, `"`, "`"), + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "debug/allocs", + MimeType: "application/octet-stream", + Read: PermitAnyone, + DataFunc: handleAllocsProfile, + Name: "Get Allocs Profile", + Description: strings.ReplaceAll(`Gather and return the memory allocation profile. 
+ + You can easily view this data in your browser with this command (with Go installed): + "go tool pprof -http :8888 http://127.0.0.1:817/api/v1/debug/allocs" + `, `"`, "`"), + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "debug/info", + Read: PermitAnyone, + DataFunc: debugInfo, + Name: "Get Debug Information", + Description: "Returns debugging information, including the version and platform info, errors, logs and the current goroutine stack.", + Parameters: []Parameter{{ + Method: http.MethodGet, + Field: "style", + Value: "github", + Description: "Specify the formatting style. The default is simple markdown formatting.", + }}, + }); err != nil { + return err + } + + return nil +} + +// ping responds with pong. +func ping(ar *Request) (msg string, err error) { + // TODO: Remove upgrade to "ready" when all UI components have transitioned. + if modules.IsStarting() || modules.IsShuttingDown() { + return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) + } + + return "Pong.", nil +} + +// ready checks if Portmaster has completed starting. +func ready(ar *Request) (msg string, err error) { + if modules.IsStarting() || modules.IsShuttingDown() { + return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) + } + return "Portmaster is ready.", nil +} + +// getStack returns the current goroutine stack. +func getStack(_ *Request) (data []byte, err error) { + buf := &bytes.Buffer{} + err = pprof.Lookup("goroutine").WriteTo(buf, 1) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// printStack prints the current goroutine stack to stderr. +func printStack(_ *Request) (msg string, err error) { + _, err = fmt.Fprint(os.Stderr, "===== PRINTING STACK =====\n") + if err == nil { + err = pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) + } + if err == nil { + _, err = fmt.Fprint(os.Stderr, "===== END OF STACK =====\n") + } + if err != nil { + return "", err + } + return "stack printed to stdout", nil +} + +// handleCPUProfile returns the CPU profile. +func handleCPUProfile(ar *Request) (data []byte, err error) { + // Parse duration. + duration := 10 * time.Second + if durationOption := ar.Request.URL.Query().Get("duration"); durationOption != "" { + parsedDuration, err := time.ParseDuration(durationOption) + if err != nil { + return nil, fmt.Errorf("failed to parse duration: %w", err) + } + duration = parsedDuration + } + + // Indicate download and filename. + ar.ResponseHeader.Set( + "Content-Disposition", + fmt.Sprintf(`attachment; filename="portmaster-cpu-profile_v%s.pprof"`, info.Version()), + ) + + // Start CPU profiling. + buf := new(bytes.Buffer) + if err := pprof.StartCPUProfile(buf); err != nil { + return nil, fmt.Errorf("failed to start cpu profile: %w", err) + } + + // Wait for the specified duration. + select { + case <-time.After(duration): + case <-ar.Context().Done(): + pprof.StopCPUProfile() + return nil, context.Canceled + } + + // Stop CPU profiling and return data. + pprof.StopCPUProfile() + return buf.Bytes(), nil +} + +// handleHeapProfile returns the Heap profile. +func handleHeapProfile(ar *Request) (data []byte, err error) { + // Indicate download and filename. 
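	// Editor's note: unlike the CPU profile above, the heap and allocs
	// profiles are point-in-time snapshots, so no duration parameter is
	// read; the profile is written to a buffer and returned as the body.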
+ ar.ResponseHeader.Set( + "Content-Disposition", + fmt.Sprintf(`attachment; filename="portmaster-memory-heap-profile_v%s.pprof"`, info.Version()), + ) + + buf := new(bytes.Buffer) + if err := pprof.Lookup("heap").WriteTo(buf, 0); err != nil { + return nil, fmt.Errorf("failed to write heap profile: %w", err) + } + return buf.Bytes(), nil +} + +// handleAllocsProfile returns the Allocs profile. +func handleAllocsProfile(ar *Request) (data []byte, err error) { + // Indicate download and filename. + ar.ResponseHeader.Set( + "Content-Disposition", + fmt.Sprintf(`attachment; filename="portmaster-memory-allocs-profile_v%s.pprof"`, info.Version()), + ) + + buf := new(bytes.Buffer) + if err := pprof.Lookup("allocs").WriteTo(buf, 0); err != nil { + return nil, fmt.Errorf("failed to write allocs profile: %w", err) + } + return buf.Bytes(), nil +} + +// debugInfo returns the debugging information for support requests. +func debugInfo(ar *Request) (data []byte, err error) { + // Create debug information helper. + di := new(debug.Info) + di.Style = ar.Request.URL.Query().Get("style") + + // Add debug information. + di.AddVersionInfo() + di.AddPlatformInfo(ar.Context()) + di.AddLastReportedModuleError() + di.AddLastUnexpectedLogs() + di.AddGoroutineStack() + + // Return data. + return di.Bytes(), nil +} diff --git a/base/api/endpoints_meta.go b/base/api/endpoints_meta.go new file mode 100644 index 000000000..74019fa72 --- /dev/null +++ b/base/api/endpoints_meta.go @@ -0,0 +1,140 @@ +package api + +import ( + "encoding/json" + "errors" + "net/http" +) + +func registerMetaEndpoints() error { + if err := RegisterEndpoint(Endpoint{ + Path: "endpoints", + Read: PermitAnyone, + MimeType: MimeTypeJSON, + DataFunc: listEndpoints, + Name: "Export API Endpoints", + Description: "Returns a list of all registered endpoints and their metadata.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "auth/permissions", + Read: Dynamic, + StructFunc: permissions, + Name: "View Current Permissions", + Description: "Returns the current permissions assigned to the request.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "auth/bearer", + Read: Dynamic, + HandlerFunc: authBearer, + Name: "Request HTTP Bearer Auth", + Description: "Returns an HTTP Bearer Auth request, if not authenticated.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "auth/basic", + Read: Dynamic, + HandlerFunc: authBasic, + Name: "Request HTTP Basic Auth", + Description: "Returns an HTTP Basic Auth request, if not authenticated.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "auth/reset", + Read: PermitAnyone, + HandlerFunc: authReset, + Name: "Reset Authenticated Session", + Description: "Resets authentication status internally and in the browser.", + }); err != nil { + return err + } + + return nil +} + +func listEndpoints(ar *Request) (data []byte, err error) { + data, err = json.Marshal(ExportEndpoints()) + return +} + +func permissions(ar *Request) (i interface{}, err error) { + if ar.AuthToken == nil { + return nil, errors.New("authentication token missing") + } + + return struct { + Read Permission + Write Permission + ReadRole string + WriteRole string + }{ + Read: ar.AuthToken.Read, + Write: ar.AuthToken.Write, + ReadRole: ar.AuthToken.Read.Role(), + WriteRole: ar.AuthToken.Write.Role(), + }, nil +} + +func authBearer(w http.ResponseWriter, r *http.Request) { + // Check if authenticated by checking 
read permission. + ar := GetAPIRequest(r) + if ar.AuthToken.Read != PermitAnyone { + TextResponse(w, r, "Authenticated.") + return + } + + // Respond with desired authentication header. + w.Header().Set( + "WWW-Authenticate", + `Bearer realm="Portmaster API" domain="/"`, + ) + http.Error(w, "Authorization required.", http.StatusUnauthorized) +} + +func authBasic(w http.ResponseWriter, r *http.Request) { + // Check if authenticated by checking read permission. + ar := GetAPIRequest(r) + if ar.AuthToken.Read != PermitAnyone { + TextResponse(w, r, "Authenticated.") + return + } + + // Respond with desired authentication header. + w.Header().Set( + "WWW-Authenticate", + `Basic realm="Portmaster API" domain="/"`, + ) + http.Error(w, "Authorization required.", http.StatusUnauthorized) +} + +func authReset(w http.ResponseWriter, r *http.Request) { + // Get session cookie from request and delete session if exists. + c, err := r.Cookie(sessionCookieName) + if err == nil { + deleteSession(c.Value) + } + + // Delete session and cookie. + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + MaxAge: -1, // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0' + }) + + // Request client to also reset all data. + w.Header().Set("Clear-Site-Data", "*") + + // Set HTTP Auth Realm without requesting authorization. + w.Header().Set("WWW-Authenticate", `None realm="Portmaster API"`) + + // Reply with 401 Unauthorized in order to clear HTTP Basic Auth data. + http.Error(w, "Session deleted.", http.StatusUnauthorized) +} diff --git a/base/api/endpoints_modules.go b/base/api/endpoints_modules.go new file mode 100644 index 000000000..22f6af3a8 --- /dev/null +++ b/base/api/endpoints_modules.go @@ -0,0 +1,56 @@ +package api + +import ( + "errors" + "fmt" + + "github.com/safing/portmaster/base/modules" +) + +func registerModulesEndpoints() error { + if err := RegisterEndpoint(Endpoint{ + Path: "modules/status", + Read: PermitUser, + StructFunc: getStatusfunc, + Name: "Get Module Status", + Description: "Returns status information of all modules.", + }); err != nil { + return err + } + + if err := RegisterEndpoint(Endpoint{ + Path: "modules/{moduleName:.+}/trigger/{eventName:.+}", + Write: PermitSelf, + ActionFunc: triggerEvent, + Name: "Trigger Event", + Description: "Triggers an event of an internal module.", + }); err != nil { + return err + } + + return nil +} + +func getStatusfunc(ar *Request) (i interface{}, err error) { + status := modules.GetStatus() + if status == nil { + return nil, errors.New("modules not yet initialized") + } + return status, nil +} + +func triggerEvent(ar *Request) (msg string, err error) { + // Get parameters. + moduleName := ar.URLVars["moduleName"] + eventName := ar.URLVars["eventName"] + if moduleName == "" || eventName == "" { + return "", errors.New("invalid parameters") + } + + // Inject event. 
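	// Editor's note: the final nil argument is presumably the event payload,
	// so the event is injected without data; the endpoint itself is
	// registered with Write: PermitSelf above.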
+ if err := module.InjectEvent("api event injection", moduleName, eventName, nil); err != nil { + return "", fmt.Errorf("failed to inject event: %w", err) + } + + return "event successfully injected", nil +} diff --git a/base/api/endpoints_test.go b/base/api/endpoints_test.go new file mode 100644 index 000000000..d24ead0ca --- /dev/null +++ b/base/api/endpoints_test.go @@ -0,0 +1,161 @@ +package api + +import ( + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/safing/portmaster/base/database/record" +) + +const ( + successMsg = "endpoint api success" + failedMsg = "endpoint api failed" +) + +type actionTestRecord struct { + record.Base + sync.Mutex + Msg string +} + +func TestEndpoints(t *testing.T) { + t.Parallel() + + testHandler := &mainHandler{ + mux: mainMux, + } + + // ActionFn + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/action", + Read: PermitAnyone, + ActionFunc: func(_ *Request) (msg string, err error) { + return successMsg, nil + }, + })) + assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action", nil, successMsg) + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/action-err", + Read: PermitAnyone, + ActionFunc: func(_ *Request) (msg string, err error) { + return "", errors.New(failedMsg) + }, + })) + assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/action-err", nil, failedMsg) + + // DataFn + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/data", + Read: PermitAnyone, + DataFunc: func(_ *Request) (data []byte, err error) { + return []byte(successMsg), nil + }, + })) + assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data", nil, successMsg) + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/data-err", + Read: PermitAnyone, + DataFunc: func(_ *Request) (data []byte, err error) { + return nil, errors.New(failedMsg) + }, + })) + assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/data-err", nil, failedMsg) + + // StructFn + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/struct", + Read: PermitAnyone, + StructFunc: func(_ *Request) (i interface{}, err error) { + return &actionTestRecord{ + Msg: successMsg, + }, nil + }, + })) + assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct", nil, successMsg) + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/struct-err", + Read: PermitAnyone, + StructFunc: func(_ *Request) (i interface{}, err error) { + return nil, errors.New(failedMsg) + }, + })) + assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/struct-err", nil, failedMsg) + + // RecordFn + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/record", + Read: PermitAnyone, + RecordFunc: func(_ *Request) (r record.Record, err error) { + r = &actionTestRecord{ + Msg: successMsg, + } + r.CreateMeta() + return r, nil + }, + })) + assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record", nil, successMsg) + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/record-err", + Read: PermitAnyone, + RecordFunc: func(_ *Request) (r record.Record, err error) { + return nil, errors.New(failedMsg) + }, + })) + assert.HTTPBodyContains(t, testHandler.ServeHTTP, "GET", apiV1Path+"test/record-err", nil, failedMsg) +} + +func TestActionRegistration(t *testing.T) { + t.Parallel() + + assert.Error(t, RegisterEndpoint(Endpoint{})) + + assert.Error(t, RegisterEndpoint(Endpoint{ + Path: "test/err", + Read: 
NotFound, + })) + assert.Error(t, RegisterEndpoint(Endpoint{ + Path: "test/err", + Read: PermitSelf + 1, + })) + + assert.Error(t, RegisterEndpoint(Endpoint{ + Path: "test/err", + Write: NotFound, + })) + assert.Error(t, RegisterEndpoint(Endpoint{ + Path: "test/err", + Write: PermitSelf + 1, + })) + + assert.Error(t, RegisterEndpoint(Endpoint{ + Path: "test/err", + })) + + assert.Error(t, RegisterEndpoint(Endpoint{ + Path: "test/err", + ActionFunc: func(_ *Request) (msg string, err error) { + return successMsg, nil + }, + DataFunc: func(_ *Request) (data []byte, err error) { + return []byte(successMsg), nil + }, + })) + + assert.NoError(t, RegisterEndpoint(Endpoint{ + Path: "test/err", + ActionFunc: func(_ *Request) (msg string, err error) { + return successMsg, nil + }, + })) +} diff --git a/base/api/enriched-response.go b/base/api/enriched-response.go new file mode 100644 index 000000000..3caecf506 --- /dev/null +++ b/base/api/enriched-response.go @@ -0,0 +1,68 @@ +package api + +import ( + "bufio" + "errors" + "net" + "net/http" + + "github.com/safing/portmaster/base/log" +) + +// LoggingResponseWriter is a wrapper for http.ResponseWriter for better request logging. +type LoggingResponseWriter struct { + ResponseWriter http.ResponseWriter + Request *http.Request + Status int +} + +// NewLoggingResponseWriter wraps a http.ResponseWriter. +func NewLoggingResponseWriter(w http.ResponseWriter, r *http.Request) *LoggingResponseWriter { + return &LoggingResponseWriter{ + ResponseWriter: w, + Request: r, + } +} + +// Header wraps the original Header method. +func (lrw *LoggingResponseWriter) Header() http.Header { + return lrw.ResponseWriter.Header() +} + +// Write wraps the original Write method. +func (lrw *LoggingResponseWriter) Write(b []byte) (int, error) { + return lrw.ResponseWriter.Write(b) +} + +// WriteHeader wraps the original WriteHeader method to extract information. +func (lrw *LoggingResponseWriter) WriteHeader(code int) { + lrw.Status = code + lrw.ResponseWriter.WriteHeader(code) +} + +// Hijack wraps the original Hijack method, if available. +func (lrw *LoggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := lrw.ResponseWriter.(http.Hijacker) + if ok { + c, b, err := hijacker.Hijack() + if err != nil { + return nil, nil, err + } + log.Tracer(lrw.Request.Context()).Infof("api request: %s HIJ %s", lrw.Request.RemoteAddr, lrw.Request.RequestURI) + return c, b, nil + } + return nil, nil, errors.New("response does not implement http.Hijacker") +} + +// RequestLogger is a logging middleware. +func RequestLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Tracer(r.Context()).Tracef("api request: %s ___ %s", r.RemoteAddr, r.RequestURI) + lrw := NewLoggingResponseWriter(w, r) + next.ServeHTTP(lrw, r) + if lrw.Status != 0 { + // request may have been hijacked + log.Tracer(r.Context()).Infof("api request: %s %d %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.RequestURI) + } + }) +} diff --git a/base/api/main.go b/base/api/main.go new file mode 100644 index 000000000..687130ff0 --- /dev/null +++ b/base/api/main.go @@ -0,0 +1,88 @@ +package api + +import ( + "encoding/json" + "errors" + "flag" + "os" + "time" + + "github.com/safing/portmaster/base/modules" +) + +var ( + module *modules.Module + + exportEndpoints bool +) + +// API Errors. 
+var ( + ErrAuthenticationAlreadySet = errors.New("the authentication function has already been set") + ErrAuthenticationImmutable = errors.New("the authentication function can only be set before the api has started") +) + +func init() { + module = modules.Register("api", prep, start, stop, "database", "config") + + flag.BoolVar(&exportEndpoints, "export-api-endpoints", false, "export api endpoint registry and exit") +} + +func prep() error { + if exportEndpoints { + modules.SetCmdLineOperation(exportEndpointsCmd) + } + + if getDefaultListenAddress() == "" { + return errors.New("no default listen address for api available") + } + + if err := registerConfig(); err != nil { + return err + } + + if err := registerDebugEndpoints(); err != nil { + return err + } + + if err := registerConfigEndpoints(); err != nil { + return err + } + + if err := registerModulesEndpoints(); err != nil { + return err + } + + return registerMetaEndpoints() +} + +func start() error { + startServer() + + _ = updateAPIKeys(module.Ctx, nil) + err := module.RegisterEventHook("config", "config change", "update API keys", updateAPIKeys) + if err != nil { + return err + } + + // start api auth token cleaner + if authFnSet.IsSet() { + module.NewTask("clean api sessions", cleanSessions).Repeat(5 * time.Minute) + } + + return registerEndpointBridgeDB() +} + +func stop() error { + return stopServer() +} + +func exportEndpointsCmd() error { + data, err := json.MarshalIndent(ExportEndpoints(), "", " ") + if err != nil { + return err + } + + _, err = os.Stdout.Write(data) + return err +} diff --git a/base/api/main_test.go b/base/api/main_test.go new file mode 100644 index 000000000..df06dc631 --- /dev/null +++ b/base/api/main_test.go @@ -0,0 +1,56 @@ +package api + +import ( + "fmt" + "os" + "testing" + + // API depends on the database for the database api. + _ "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/modules" +) + +func init() { + defaultListenAddress = "127.0.0.1:8817" +} + +func TestMain(m *testing.M) { + // enable module for testing + module.Enable() + + // tmp dir for data root (db & config) + tmpDir, err := os.MkdirTemp("", "portbase-testing-") + if err != nil { + fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err) + os.Exit(1) + } + // initialize data dir + err = dataroot.Initialize(tmpDir, 0o0755) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) + os.Exit(1) + } + + // start modules + var exitCode int + err = modules.Start() + if err != nil { + // starting failed + fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err) + exitCode = 1 + } else { + // run tests + exitCode = m.Run() + } + + // shutdown + _ = modules.Shutdown() + if modules.GetExitStatusCode() != 0 { + exitCode = modules.GetExitStatusCode() + fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err) + } + // clean up and exit + _ = os.RemoveAll(tmpDir) + os.Exit(exitCode) +} diff --git a/base/api/modules.go b/base/api/modules.go new file mode 100644 index 000000000..c5c366db8 --- /dev/null +++ b/base/api/modules.go @@ -0,0 +1,49 @@ +package api + +import ( + "time" + + "github.com/safing/portmaster/base/modules" +) + +// ModuleHandler specifies the interface for API endpoints that are bound to a module. 
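// Editor's note (illustrative sketch, not part of this patch): the router
// checks for this interface in mainHandler.handle below and, via
// moduleIsReady, waits briefly for the owning module to come online before
// dispatching the request. A conforming handler, assuming a hypothetical
// myModule variable, might look like:
//
//	type readyHandler struct{}
//
//	func (readyHandler) BelongsTo() *modules.Module { return myModule }
//
//	func (readyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
//		fmt.Fprintln(w, "module is online")
//	}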
+type ModuleHandler interface { + BelongsTo() *modules.Module +} + +const ( + moduleCheckMaxWait = 10 * time.Second + moduleCheckTickDuration = 500 * time.Millisecond +) + +// moduleIsReady checks if the given module is online and http requests can be +// sent its way. If the module is not online already, it will wait for a short +// duration for it to come online. +func moduleIsReady(m *modules.Module) (ok bool) { + // Check if we are given a module. + if m == nil { + // If no module is given, we assume that the handler has not been linked to + // a module, and we can safely continue with the request. + return true + } + + // Check if the module is online. + if m.Online() { + return true + } + + // Check if the module will come online. + if m.OnlineSoon() { + var i time.Duration + for i = 0; i < moduleCheckMaxWait; i += moduleCheckTickDuration { + // Wait a little. + time.Sleep(moduleCheckTickDuration) + // Check if module is now online. + if m.Online() { + return true + } + } + } + + return false +} diff --git a/base/api/request.go b/base/api/request.go new file mode 100644 index 000000000..125876474 --- /dev/null +++ b/base/api/request.go @@ -0,0 +1,60 @@ +package api + +import ( + "fmt" + "net/http" + + "github.com/gorilla/mux" + + "github.com/safing/portmaster/base/log" +) + +// Request is a support struct to pool more request related information. +type Request struct { + // Request is the http request. + *http.Request + + // InputData contains the request body for write operations. + InputData []byte + + // Route of this request. + Route *mux.Route + + // URLVars contains the URL variables extracted by the gorilla mux. + URLVars map[string]string + + // AuthToken is the request-side authentication token assigned. + AuthToken *AuthToken + + // ResponseHeader holds the response header. + ResponseHeader http.Header + + // HandlerCache can be used by handlers to cache data between handlers within a request. + HandlerCache interface{} +} + +// apiRequestContextKey is a key used for the context key/value storage. +type apiRequestContextKey struct{} + +// RequestContextKey is the key used to add the API request to the context. +var RequestContextKey = apiRequestContextKey{} + +// GetAPIRequest returns the API Request of the given http request. +func GetAPIRequest(r *http.Request) *Request { + ar, ok := r.Context().Value(RequestContextKey).(*Request) + if ok { + return ar + } + return nil +} + +// TextResponse writes a text response. +func TextResponse(w http.ResponseWriter, r *http.Request, text string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.WriteHeader(http.StatusOK) + _, err := fmt.Fprintln(w, text) + if err != nil { + log.Tracer(r.Context()).Warningf("api: failed to write text response: %s", err) + } +} diff --git a/base/api/router.go b/base/api/router.go new file mode 100644 index 000000000..d8c3c3d20 --- /dev/null +++ b/base/api/router.go @@ -0,0 +1,334 @@ +package api + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "runtime/debug" + "strings" + "sync" + "time" + + "github.com/gorilla/mux" + + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" +) + +// EnableServer defines if the HTTP server should be started. +var EnableServer = true + +var ( + // mainMux is the main mux router. + mainMux = mux.NewRouter() + + // server is the main server. 
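	// Editor's note: only ReadHeaderTimeout is set here; Addr and Handler
	// are filled in by startServer below once the module starts.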
+ server = &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + } + handlerLock sync.RWMutex + + allowedDevCORSOrigins = []string{ + "127.0.0.1", + "localhost", + } +) + +// RegisterHandler registers a handler with the API endpoint. +func RegisterHandler(path string, handler http.Handler) *mux.Route { + handlerLock.Lock() + defer handlerLock.Unlock() + return mainMux.Handle(path, handler) +} + +// RegisterHandleFunc registers a handle function with the API endpoint. +func RegisterHandleFunc(path string, handleFunc func(http.ResponseWriter, *http.Request)) *mux.Route { + handlerLock.Lock() + defer handlerLock.Unlock() + return mainMux.HandleFunc(path, handleFunc) +} + +func startServer() { + // Check if server is enabled. + if !EnableServer { + return + } + + // Configure server. + server.Addr = listenAddressConfig() + server.Handler = &mainHandler{ + // TODO: mainMux should not be modified anymore. + mux: mainMux, + } + + // Start server manager. + module.StartServiceWorker("http server manager", 0, serverManager) +} + +func stopServer() error { + // Check if server is enabled. + if !EnableServer { + return nil + } + + if server.Addr != "" { + return server.Shutdown(context.Background()) + } + + return nil +} + +// Serve starts serving the API endpoint. +func serverManager(_ context.Context) error { + // start serving + log.Infof("api: starting to listen on %s", server.Addr) + backoffDuration := 10 * time.Second + for { + // always returns an error + err := module.RunWorker("http endpoint", func(ctx context.Context) error { + return server.ListenAndServe() + }) + // return on shutdown error + if errors.Is(err, http.ErrServerClosed) { + return nil + } + // log error and restart + log.Errorf("api: http endpoint failed: %s - restarting in %s", err, backoffDuration) + time.Sleep(backoffDuration) + } +} + +type mainHandler struct { + mux *mux.Router +} + +func (mh *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + _ = module.RunWorker("http request", func(_ context.Context) error { + return mh.handle(w, r) + }) +} + +func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error { + // Setup context trace logging. + ctx, tracer := log.AddTracer(r.Context()) + // Add request context. + apiRequest := &Request{ + Request: r, + } + ctx = context.WithValue(ctx, RequestContextKey, apiRequest) + // Add context back to request. + r = r.WithContext(ctx) + lrw := NewLoggingResponseWriter(w, r) + + tracer.Tracef("api request: %s ___ %s %s", r.RemoteAddr, lrw.Request.Method, r.RequestURI) + defer func() { + // Log request status. + if lrw.Status != 0 { + // If lrw.Status is 0, the request may have been hijacked. + tracer.Debugf("api request: %s %d %s %s", lrw.Request.RemoteAddr, lrw.Status, lrw.Request.Method, lrw.Request.RequestURI) + } + tracer.Submit() + }() + + // Add security headers. + w.Header().Set("Referrer-Policy", "same-origin") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "deny") + w.Header().Set("X-XSS-Protection", "1; mode=block") + w.Header().Set("X-DNS-Prefetch-Control", "off") + + // Add CSP Header in production mode. + if !devMode() { + w.Header().Set( + "Content-Security-Policy", + "default-src 'self'; "+ + "connect-src https://*.safing.io 'self'; "+ + "style-src 'self' 'unsafe-inline'; "+ + "img-src 'self' data: blob:", + ) + } + + // Check Cross-Origin Requests. + origin := r.Header.Get("Origin") + isPreflighCheck := false + if origin != "" { + + // Parse origin URL. 
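		// Editor's note: the parsed origin is compared against the Host
		// header below; mismatches are rejected unless the request comes
		// from a browser extension (chrome-extension scheme) or, in dev
		// mode, from the allowed 127.0.0.1/localhost origins.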
+ originURL, err := url.Parse(origin) + if err != nil { + tracer.Warningf("api: denied request from %s: failed to parse origin header: %s", r.RemoteAddr, err) + http.Error(lrw, "Invalid Origin.", http.StatusForbidden) + return nil + } + + // Check if the Origin matches the Host. + switch { + case originURL.Host == r.Host: + // Origin (with port) matches Host. + case originURL.Hostname() == r.Host: + // Origin (without port) matches Host. + case originURL.Scheme == "chrome-extension": + // Allow access for the browser extension + // TODO(ppacher): + // This currently allows access from any browser extension. + // Can we reduce that to only our browser extension? + // Also, what do we need to support Firefox? + case devMode() && + utils.StringInSlice(allowedDevCORSOrigins, originURL.Hostname()): + // We are in dev mode and the request is coming from the allowed + // development origins. + default: + // Origin and Host do NOT match! + tracer.Warningf("api: denied request from %s: Origin (`%s`) and Host (`%s`) do not match", r.RemoteAddr, origin, r.Host) + http.Error(lrw, "Cross-Origin Request Denied.", http.StatusForbidden) + return nil + + // If the Host header has a port, and the Origin does not, requests will + // also end up here, as we cannot properly check for equality. + } + + // Add Cross-Site Headers now as we need them in any case now. + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Expose-Headers", "*") + w.Header().Set("Access-Control-Max-Age", "60") + w.Header().Add("Vary", "Origin") + + // if there's a Access-Control-Request-Method header this is a Preflight check. + // In that case, we will just check if the preflighMethod is allowed and then return + // success here + if preflighMethod := r.Header.Get("Access-Control-Request-Method"); r.Method == http.MethodOptions && preflighMethod != "" { + isPreflighCheck = true + } + } + + // Clean URL. + cleanedRequestPath := cleanRequestPath(r.URL.Path) + + // If the cleaned URL differs from the original one, redirect to there. + if r.URL.Path != cleanedRequestPath { + redirURL := *r.URL + redirURL.Path = cleanedRequestPath + http.Redirect(lrw, r, redirURL.String(), http.StatusMovedPermanently) + return nil + } + + // Get handler for request. + // Gorilla does not support handling this on our own very well. + // See github.com/gorilla/mux.ServeHTTP for reference. + var match mux.RouteMatch + var handler http.Handler + if mh.mux.Match(r, &match) { + handler = match.Handler + apiRequest.Route = match.Route + apiRequest.URLVars = match.Vars + } + switch { + case match.MatchErr == nil: + // All good. + case errors.Is(match.MatchErr, mux.ErrMethodMismatch): + http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed) + return nil + default: + tracer.Debug("api: no handler registered for this path") + http.Error(lrw, "Not found.", http.StatusNotFound) + return nil + } + + // Be sure that URLVars always is a map. + if apiRequest.URLVars == nil { + apiRequest.URLVars = make(map[string]string) + } + + // Check method. + _, readMethod, ok := getEffectiveMethod(r) + if !ok { + http.Error(lrw, "Method not allowed.", http.StatusMethodNotAllowed) + return nil + } + + // At this point we know the method is allowed and there's a handler for the request. + // If this is just a CORS-Preflight, we'll accept the request with StatusOK now. 
+ // There's no point in trying to authenticate the request because the Browser will + // not send authentication along a preflight check. + if isPreflighCheck && handler != nil { + lrw.WriteHeader(http.StatusOK) + return nil + } + + // Check authentication. + apiRequest.AuthToken = authenticateRequest(lrw, r, handler, readMethod) + if apiRequest.AuthToken == nil { + // Authenticator already replied. + return nil + } + + // Wait for the owning module to be ready. + if moduleHandler, ok := handler.(ModuleHandler); ok { + if !moduleIsReady(moduleHandler.BelongsTo()) { + http.Error(lrw, "The API endpoint is not ready yet. Reload (F5) to try again.", http.StatusServiceUnavailable) + return nil + } + } + + // Check if we have a handler. + if handler == nil { + http.Error(lrw, "Not found.", http.StatusNotFound) + return nil + } + + // Format panics in handler. + defer func() { + if panicValue := recover(); panicValue != nil { + // Report failure via module system. + me := module.NewPanicError("api request", "custom", panicValue) + me.Report() + // Respond with a server error. + if devMode() { + http.Error( + lrw, + fmt.Sprintf( + "Internal Server Error: %s\n\n%s", + panicValue, + debug.Stack(), + ), + http.StatusInternalServerError, + ) + } else { + http.Error(lrw, "Internal Server Error.", http.StatusInternalServerError) + } + } + }() + + // Handle with registered handler. + handler.ServeHTTP(lrw, r) + + return nil +} + +// cleanRequestPath cleans and returns a request URL. +func cleanRequestPath(requestPath string) string { + // If the request URL is empty, return a request for "root". + if requestPath == "" || requestPath == "/" { + return "/" + } + // If the request URL does not start with a slash, prepend it. + if !strings.HasPrefix(requestPath, "/") { + requestPath = "/" + requestPath + } + + // Clean path to remove any relative parts. + cleanedRequestPath := path.Clean(requestPath) + // Because path.Clean removes a trailing slash, we need to add it back here + // if the original URL had one. + if strings.HasSuffix(requestPath, "/") { + cleanedRequestPath += "/" + } + + return cleanedRequestPath +} diff --git a/base/api/testclient/root/index.html b/base/api/testclient/root/index.html new file mode 100644 index 000000000..01e9e987d --- /dev/null +++ b/base/api/testclient/root/index.html @@ -0,0 +1,49 @@ + + + + + + + + + yeeee + + diff --git a/base/api/testclient/serve.go b/base/api/testclient/serve.go new file mode 100644 index 000000000..3d4551b1d --- /dev/null +++ b/base/api/testclient/serve.go @@ -0,0 +1,11 @@ +package testclient + +import ( + "net/http" + + "github.com/safing/portmaster/base/api" +) + +func init() { + api.RegisterHandler("/test/", http.StripPrefix("/test/", http.FileServer(http.Dir("./api/testclient/root/")))) +} diff --git a/base/apprise/notify.go b/base/apprise/notify.go new file mode 100644 index 000000000..2040171c7 --- /dev/null +++ b/base/apprise/notify.go @@ -0,0 +1,167 @@ +package apprise + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sync" + + "github.com/safing/portmaster/base/utils" +) + +// Notifier sends messsages to an Apprise API. +type Notifier struct { + // URL defines the Apprise API endpoint. + URL string + + // DefaultType defines the default message type. + DefaultType MsgType + + // DefaultTag defines the default message tag. + DefaultTag string + + // DefaultFormat defines the default message format. 
+ DefaultFormat MsgFormat + + // AllowUntagged defines if untagged messages are allowed, + // which are sent to all configured apprise endpoints. + AllowUntagged bool + + client *http.Client + clientLock sync.Mutex +} + +// Message represents the message to be sent to the Apprise API. +type Message struct { + // Title is an optional title to go along with the body. + Title string `json:"title,omitempty"` + + // Body is the main message content. This is the only required field. + Body string `json:"body"` + + // Type defines the message type you want to send as. + // The valid options are info, success, warning, and failure. + // If no type is specified then info is the default value used. + Type MsgType `json:"type,omitempty"` + + // Tag is used to notify only those tagged accordingly. + // Use a comma (,) to OR your tags and a space ( ) to AND them. + Tag string `json:"tag,omitempty"` + + // Format optionally identifies the text format of the data you're feeding Apprise. + // The valid options are text, markdown, html. + // The default value if nothing is specified is text. + Format MsgFormat `json:"format,omitempty"` +} + +// MsgType defines the message type. +type MsgType string + +// Message Types. +const ( + TypeInfo MsgType = "info" + TypeSuccess MsgType = "success" + TypeWarning MsgType = "warning" + TypeFailure MsgType = "failure" +) + +// MsgFormat defines the message format. +type MsgFormat string + +// Message Formats. +const ( + FormatText MsgFormat = "text" + FormatMarkdown MsgFormat = "markdown" + FormatHTML MsgFormat = "html" +) + +type errorResponse struct { + Error string `json:"error"` +} + +// Send sends a message to the Apprise API. +func (n *Notifier) Send(ctx context.Context, m *Message) error { + // Check if the message has a body. + if m.Body == "" { + return errors.New("the message must have a body") + } + + // Apply notifier defaults. + n.applyDefaults(m) + + // Check if the message is tagged. + if m.Tag == "" && !n.AllowUntagged { + return errors.New("the message must have a tag") + } + + // Marshal the message to JSON. + payload, err := json.Marshal(m) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Create request. + request, err := http.NewRequestWithContext(ctx, http.MethodPost, n.URL, bytes.NewReader(payload)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + request.Header.Set("Content-Type", "application/json") + + // Send message to API. + resp, err := n.getClient().Do(request) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + defer resp.Body.Close() //nolint:errcheck,gosec + switch resp.StatusCode { + case http.StatusOK, http.StatusCreated, http.StatusNoContent, http.StatusAccepted: + return nil + default: + // Try to tease body contents. + if body, err := io.ReadAll(resp.Body); err == nil && len(body) > 0 { + // Try to parse json response. 
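			// Editor's note: the code expects Apprise error bodies to be
			// JSON of the form {"error": "..."}; if that fails to parse,
			// a short teaser of the raw body is returned instead.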
+ errorResponse := &errorResponse{} + if err := json.Unmarshal(body, errorResponse); err == nil && errorResponse.Error != "" { + return fmt.Errorf("failed to send message: apprise returned %q with an error message: %s", resp.Status, errorResponse.Error) + } + return fmt.Errorf("failed to send message: %s (body teaser: %s)", resp.Status, utils.SafeFirst16Bytes(body)) + } + return fmt.Errorf("failed to send message: %s", resp.Status) + } +} + +func (n *Notifier) applyDefaults(m *Message) { + if m.Type == "" { + m.Type = n.DefaultType + } + if m.Tag == "" { + m.Tag = n.DefaultTag + } + if m.Format == "" { + m.Format = n.DefaultFormat + } +} + +// SetClient sets a custom http client for accessing the Apprise API. +func (n *Notifier) SetClient(client *http.Client) { + n.clientLock.Lock() + defer n.clientLock.Unlock() + + n.client = client +} + +func (n *Notifier) getClient() *http.Client { + n.clientLock.Lock() + defer n.clientLock.Unlock() + + // Create client if needed. + if n.client == nil { + n.client = &http.Client{} + } + + return n.client +} diff --git a/base/config/basic_config.go b/base/config/basic_config.go new file mode 100644 index 000000000..7898df127 --- /dev/null +++ b/base/config/basic_config.go @@ -0,0 +1,113 @@ +package config + +import ( + "context" + "flag" + + "github.com/safing/portmaster/base/log" +) + +// Configuration Keys. +var ( + CfgDevModeKey = "core/devMode" + defaultDevMode bool + + CfgLogLevel = "core/log/level" + defaultLogLevel = log.InfoLevel.String() + logLevel StringOption +) + +func init() { + flag.BoolVar(&defaultDevMode, "devmode", false, "enable development mode; configuration is stronger") +} + +func registerBasicOptions() error { + // Get the default log level from the log package. + defaultLogLevel = log.GetLogLevel().Name() + + // Register logging setting. + // The log package cannot do that, as it would trigger and import loop. + if err := Register(&Option{ + Name: "Log Level", + Key: CfgLogLevel, + Description: "Configure the logging level.", + OptType: OptTypeString, + ExpertiseLevel: ExpertiseLevelDeveloper, + ReleaseLevel: ReleaseLevelStable, + DefaultValue: defaultLogLevel, + Annotations: Annotations{ + DisplayOrderAnnotation: 513, + DisplayHintAnnotation: DisplayHintOneOf, + CategoryAnnotation: "Development", + }, + PossibleValues: []PossibleValue{ + { + Name: "Critical", + Value: "critical", + Description: "The critical level only logs errors that lead to a partial, but imminent failure.", + }, + { + Name: "Error", + Value: "error", + Description: "The error level logs errors that potentially break functionality. Everything logged by the critical level is included here too.", + }, + { + Name: "Warning", + Value: "warning", + Description: "The warning level logs minor errors and worse. Everything logged by the error level is included here too.", + }, + { + Name: "Info", + Value: "info", + Description: "The info level logs the main events that are going on and are interesting to the user. Everything logged by the warning level is included here too.", + }, + { + Name: "Debug", + Value: "debug", + Description: "The debug level logs some additional debugging details. Everything logged by the info level is included here too.", + }, + { + Name: "Trace", + Value: "trace", + Description: "The trace level logs loads of detailed information as well as operation and request traces. 
Everything logged by the debug level is included here too.", + }, + }, + }); err != nil { + return err + } + logLevel = GetAsString(CfgLogLevel, defaultLogLevel) + + // Register to hook to update the log level. + if err := module.RegisterEventHook( + "config", + ChangeEvent, + "update log level", + setLogLevel, + ); err != nil { + return err + } + + return Register(&Option{ + Name: "Development Mode", + Key: CfgDevModeKey, + Description: "In Development Mode, security restrictions are lifted/softened to enable unrestricted access for debugging and testing purposes.", + OptType: OptTypeBool, + ExpertiseLevel: ExpertiseLevelDeveloper, + ReleaseLevel: ReleaseLevelStable, + DefaultValue: defaultDevMode, + Annotations: Annotations{ + DisplayOrderAnnotation: 512, + CategoryAnnotation: "Development", + }, + }) +} + +func loadLogLevel() error { + return setDefaultConfigOption(CfgLogLevel, log.GetLogLevel().Name(), false) +} + +func setLogLevel(ctx context.Context, data interface{}) error { + log.SetLogLevel(log.ParseLevel(logLevel())) + + return nil +} diff --git a/base/config/database.go b/base/config/database.go new file mode 100644 index 000000000..8e3e6564f --- /dev/null +++ b/base/config/database.go @@ -0,0 +1,169 @@ +package config + +import ( + "errors" + "sort" + "strings" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" + "github.com/safing/portmaster/base/log" +) + +var dbController *database.Controller + +// StorageInterface provices a storage.Interface to the configuration manager. +type StorageInterface struct { + storage.InjectBase +} + +// Get returns a database record. +func (s *StorageInterface) Get(key string) (record.Record, error) { + opt, err := GetOption(key) + if err != nil { + return nil, storage.ErrNotFound + } + + return opt.Export() +} + +// Put stores a record in the database. +func (s *StorageInterface) Put(r record.Record) (record.Record, error) { + if r.Meta().Deleted > 0 { + return r, setConfigOption(r.DatabaseKey(), nil, false) + } + + acc := r.GetAccessor(r) + if acc == nil { + return nil, errors.New("invalid data") + } + + val, ok := acc.Get("Value") + if !ok || val == nil { + err := setConfigOption(r.DatabaseKey(), nil, false) + if err != nil { + return nil, err + } + return s.Get(r.DatabaseKey()) + } + + option, err := GetOption(r.DatabaseKey()) + if err != nil { + return nil, err + } + + var value interface{} + switch option.OptType { + case OptTypeString: + value, ok = acc.GetString("Value") + case OptTypeStringArray: + value, ok = acc.GetStringArray("Value") + case OptTypeInt: + value, ok = acc.GetInt("Value") + case OptTypeBool: + value, ok = acc.GetBool("Value") + case optTypeAny: + ok = false + } + if !ok { + return nil, errors.New("received invalid value in \"Value\"") + } + + if err := setConfigOption(r.DatabaseKey(), value, false); err != nil { + return nil, err + } + return option.Export() +} + +// Delete deletes a record from the database. +func (s *StorageInterface) Delete(key string) error { + return setConfigOption(key, nil, false) +} + +// Query returns a an iterator for the supplied query. 
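// Editor's note: the query is only filtered by database key prefix here;
// matching options are exported as records, sorted by key and streamed
// through the iterator in processQuery below.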
+func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + optionsLock.RLock() + defer optionsLock.RUnlock() + + it := iterator.New() + var opts []*Option + for _, opt := range options { + if strings.HasPrefix(opt.Key, q.DatabaseKeyPrefix()) { + opts = append(opts, opt) + } + } + + go s.processQuery(it, opts) + + return it, nil +} + +func (s *StorageInterface) processQuery(it *iterator.Iterator, opts []*Option) { + sort.Sort(sortByKey(opts)) + + for _, opt := range opts { + r, err := opt.Export() + if err != nil { + it.Finish(err) + return + } + it.Next <- r + } + + it.Finish(nil) +} + +// ReadOnly returns whether the database is read only. +func (s *StorageInterface) ReadOnly() bool { + return false +} + +func registerAsDatabase() error { + _, err := database.Register(&database.Database{ + Name: "config", + Description: "Configuration Manager", + StorageType: "injected", + }) + if err != nil { + return err + } + + controller, err := database.InjectDatabase("config", &StorageInterface{}) + if err != nil { + return err + } + + dbController = controller + return nil +} + +// handleOptionUpdate updates the expertise and release level options, +// if required, and eventually pushes a update for the option. +// The caller must hold the option lock. +func handleOptionUpdate(option *Option, push bool) { + if expertiseLevelOptionFlag.IsSet() && option == expertiseLevelOption { + updateExpertiseLevel() + } + + if releaseLevelOptionFlag.IsSet() && option == releaseLevelOption { + updateReleaseLevel() + } + + if push { + pushUpdate(option) + } +} + +// pushUpdate pushes an database update notification for option. +// The caller must hold the option lock. +func pushUpdate(option *Option) { + r, err := option.export() + if err != nil { + log.Errorf("failed to export option to push update: %s", err) + } else { + dbController.PushUpdate(r) + } +} diff --git a/base/config/doc.go b/base/config/doc.go new file mode 100644 index 000000000..6c023dca7 --- /dev/null +++ b/base/config/doc.go @@ -0,0 +1,2 @@ +// Package config provides a versatile configuration management system. +package config diff --git a/base/config/expertise.go b/base/config/expertise.go new file mode 100644 index 000000000..f6f786160 --- /dev/null +++ b/base/config/expertise.go @@ -0,0 +1,104 @@ +package config + +import ( + "sync/atomic" + + "github.com/tevino/abool" +) + +// ExpertiseLevel allows to group settings by user expertise. +// It's useful if complex or technical settings should be hidden +// from the average user while still allowing experts and developers +// to change deep configuration settings. +type ExpertiseLevel uint8 + +// Expertise Level constants. +const ( + ExpertiseLevelUser ExpertiseLevel = 0 + ExpertiseLevelExpert ExpertiseLevel = 1 + ExpertiseLevelDeveloper ExpertiseLevel = 2 + + ExpertiseLevelNameUser = "user" + ExpertiseLevelNameExpert = "expert" + ExpertiseLevelNameDeveloper = "developer" + + expertiseLevelKey = "core/expertiseLevel" +) + +var ( + expertiseLevelOption *Option + expertiseLevel = new(int32) + expertiseLevelOptionFlag = abool.New() +) + +func init() { + registerExpertiseLevelOption() +} + +func registerExpertiseLevelOption() { + expertiseLevelOption = &Option{ + Name: "UI Mode", + Key: expertiseLevelKey, + Description: "Control the default amount of settings and information shown. Hidden settings are still in effect. 
Can be changed temporarily in the top right corner.", + OptType: OptTypeString, + ExpertiseLevel: ExpertiseLevelUser, + ReleaseLevel: ReleaseLevelStable, + DefaultValue: ExpertiseLevelNameUser, + Annotations: Annotations{ + DisplayOrderAnnotation: -16, + DisplayHintAnnotation: DisplayHintOneOf, + CategoryAnnotation: "User Interface", + }, + PossibleValues: []PossibleValue{ + { + Name: "Simple Interface", + Value: ExpertiseLevelNameUser, + Description: "Hide complex settings and information.", + }, + { + Name: "Advanced Interface", + Value: ExpertiseLevelNameExpert, + Description: "Show technical details.", + }, + { + Name: "Developer Interface", + Value: ExpertiseLevelNameDeveloper, + Description: "Developer mode. Please be careful!", + }, + }, + } + + err := Register(expertiseLevelOption) + if err != nil { + panic(err) + } + + expertiseLevelOptionFlag.Set() +} + +func updateExpertiseLevel() { + // get value + value := expertiseLevelOption.activeFallbackValue + if expertiseLevelOption.activeValue != nil { + value = expertiseLevelOption.activeValue + } + if expertiseLevelOption.activeDefaultValue != nil { + value = expertiseLevelOption.activeDefaultValue + } + // set atomic value + switch value.stringVal { + case ExpertiseLevelNameUser: + atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser)) + case ExpertiseLevelNameExpert: + atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelExpert)) + case ExpertiseLevelNameDeveloper: + atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelDeveloper)) + default: + atomic.StoreInt32(expertiseLevel, int32(ExpertiseLevelUser)) + } +} + +// GetExpertiseLevel returns the current active expertise level. +func GetExpertiseLevel() uint8 { + return uint8(atomic.LoadInt32(expertiseLevel)) +} diff --git a/base/config/get-safe.go b/base/config/get-safe.go new file mode 100644 index 000000000..f03d4532e --- /dev/null +++ b/base/config/get-safe.go @@ -0,0 +1,112 @@ +package config + +import "sync" + +type safe struct{} + +// Concurrent makes concurrency safe get methods available. +var Concurrent = &safe{} + +// GetAsString returns a function that returns the wanted string with high performance. +func (cs *safe) GetAsString(name string, fallback string) StringOption { + valid := getValidityFlag() + option, valueCache := getValueCache(name, nil, OptTypeString) + value := fallback + if valueCache != nil { + value = valueCache.stringVal + } + var lock sync.Mutex + + return func() string { + lock.Lock() + defer lock.Unlock() + if !valid.IsSet() { + valid = getValidityFlag() + option, valueCache = getValueCache(name, option, OptTypeString) + if valueCache != nil { + value = valueCache.stringVal + } else { + value = fallback + } + } + return value + } +} + +// GetAsStringArray returns a function that returns the wanted string with high performance. +func (cs *safe) GetAsStringArray(name string, fallback []string) StringArrayOption { + valid := getValidityFlag() + option, valueCache := getValueCache(name, nil, OptTypeStringArray) + value := fallback + if valueCache != nil { + value = valueCache.stringArrayVal + } + var lock sync.Mutex + + return func() []string { + lock.Lock() + defer lock.Unlock() + if !valid.IsSet() { + valid = getValidityFlag() + option, valueCache = getValueCache(name, option, OptTypeStringArray) + if valueCache != nil { + value = valueCache.stringArrayVal + } else { + value = fallback + } + } + return value + } +} + +// GetAsInt returns a function that returns the wanted int with high performance. 
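// Editor's note (illustrative sketch, not part of this patch): usage of the
// concurrency-safe getters, assuming a hypothetical "core/timeout" option
// registered as OptTypeInt:
//
//	timeout := config.Concurrent.GetAsInt("core/timeout", 30)
//	_ = timeout() // re-resolves via the validity flag after config changes
//
// Unlike the package-level getters in get.go, these variants guard their
// cached value with a mutex.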
+func (cs *safe) GetAsInt(name string, fallback int64) IntOption { + valid := getValidityFlag() + option, valueCache := getValueCache(name, nil, OptTypeInt) + value := fallback + if valueCache != nil { + value = valueCache.intVal + } + var lock sync.Mutex + + return func() int64 { + lock.Lock() + defer lock.Unlock() + if !valid.IsSet() { + valid = getValidityFlag() + option, valueCache = getValueCache(name, option, OptTypeInt) + if valueCache != nil { + value = valueCache.intVal + } else { + value = fallback + } + } + return value + } +} + +// GetAsBool returns a function that returns the wanted int with high performance. +func (cs *safe) GetAsBool(name string, fallback bool) BoolOption { + valid := getValidityFlag() + option, valueCache := getValueCache(name, nil, OptTypeBool) + value := fallback + if valueCache != nil { + value = valueCache.boolVal + } + var lock sync.Mutex + + return func() bool { + lock.Lock() + defer lock.Unlock() + if !valid.IsSet() { + valid = getValidityFlag() + option, valueCache = getValueCache(name, option, OptTypeBool) + if valueCache != nil { + value = valueCache.boolVal + } else { + value = fallback + } + } + return value + } +} diff --git a/base/config/get.go b/base/config/get.go new file mode 100644 index 000000000..3d41a2184 --- /dev/null +++ b/base/config/get.go @@ -0,0 +1,174 @@ +package config + +import ( + "github.com/safing/portmaster/base/log" +) + +type ( + // StringOption defines the returned function by GetAsString. + StringOption func() string + // StringArrayOption defines the returned function by GetAsStringArray. + StringArrayOption func() []string + // IntOption defines the returned function by GetAsInt. + IntOption func() int64 + // BoolOption defines the returned function by GetAsBool. + BoolOption func() bool +) + +func getValueCache(name string, option *Option, requestedType OptionType) (*Option, *valueCache) { + // get option + if option == nil { + var err error + option, err = GetOption(name) + if err != nil { + log.Errorf("config: request for unregistered option: %s", name) + return nil, nil + } + } + + // Check the option type, no locking required as + // OptType is immutable once it is set + if requestedType != option.OptType { + log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(option.OptType)) + return option, nil + } + + option.Lock() + defer option.Unlock() + + // check release level + if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil { + return option, option.activeValue + } + + if option.activeDefaultValue != nil { + return option, option.activeDefaultValue + } + + return option, option.activeFallbackValue +} + +// GetAsString returns a function that returns the wanted string with high performance. +func GetAsString(name string, fallback string) StringOption { + valid := getValidityFlag() + option, valueCache := getValueCache(name, nil, OptTypeString) + value := fallback + if valueCache != nil { + value = valueCache.stringVal + } + + return func() string { + if !valid.IsSet() { + valid = getValidityFlag() + option, valueCache = getValueCache(name, option, OptTypeString) + if valueCache != nil { + value = valueCache.stringVal + } else { + value = fallback + } + } + return value + } +} + +// GetAsStringArray returns a function that returns the wanted string with high performance. 
+func GetAsStringArray(name string, fallback []string) StringArrayOption { + valid := getValidityFlag() + option, valueCache := getValueCache(name, nil, OptTypeStringArray) + value := fallback + if valueCache != nil { + value = valueCache.stringArrayVal + } + + return func() []string { + if !valid.IsSet() { + valid = getValidityFlag() + option, valueCache = getValueCache(name, option, OptTypeStringArray) + if valueCache != nil { + value = valueCache.stringArrayVal + } else { + value = fallback + } + } + return value + } +} + +// GetAsInt returns a function that returns the wanted int with high performance. +func GetAsInt(name string, fallback int64) IntOption { + valid := getValidityFlag() + option, valueCache := getValueCache(name, nil, OptTypeInt) + value := fallback + if valueCache != nil { + value = valueCache.intVal + } + + return func() int64 { + if !valid.IsSet() { + valid = getValidityFlag() + option, valueCache = getValueCache(name, option, OptTypeInt) + if valueCache != nil { + value = valueCache.intVal + } else { + value = fallback + } + } + return value + } +} + +// GetAsBool returns a function that returns the wanted int with high performance. +func GetAsBool(name string, fallback bool) BoolOption { + valid := getValidityFlag() + option, valueCache := getValueCache(name, nil, OptTypeBool) + value := fallback + if valueCache != nil { + value = valueCache.boolVal + } + + return func() bool { + if !valid.IsSet() { + valid = getValidityFlag() + option, valueCache = getValueCache(name, option, OptTypeBool) + if valueCache != nil { + value = valueCache.boolVal + } else { + value = fallback + } + } + return value + } +} + +/* +func getAndFindValue(key string) interface{} { + optionsLock.RLock() + option, ok := options[key] + optionsLock.RUnlock() + if !ok { + log.Errorf("config: request for unregistered option: %s", key) + return nil + } + + return option.findValue() +} +*/ + +/* +// findValue finds the preferred value in the user or default config. 
+func (option *Option) findValue() interface{} { + // lock option + option.Lock() + defer option.Unlock() + + if option.ReleaseLevel <= getReleaseLevel() && option.activeValue != nil { + return option.activeValue + } + + if option.activeDefaultValue != nil { + return option.activeDefaultValue + } + + return option.DefaultValue +} +*/ diff --git a/base/config/get_test.go b/base/config/get_test.go new file mode 100644 index 000000000..810631abc --- /dev/null +++ b/base/config/get_test.go @@ -0,0 +1,368 @@ +package config + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/safing/portmaster/base/log" +) + +func parseAndReplaceConfig(jsonData string) error { + m, err := JSONToMap([]byte(jsonData)) + if err != nil { + return err + } + + validationErrors, _ := ReplaceConfig(m) + if len(validationErrors) > 0 { + return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0]) + } + return nil +} + +func parseAndReplaceDefaultConfig(jsonData string) error { + m, err := JSONToMap([]byte(jsonData)) + if err != nil { + return err + } + + validationErrors, _ := ReplaceDefaultConfig(m) + if len(validationErrors) > 0 { + return fmt.Errorf("%d errors, first: %w", len(validationErrors), validationErrors[0]) + } + return nil +} + +func quickRegister(t *testing.T, key string, optType OptionType, defaultValue interface{}) { + t.Helper() + + err := Register(&Option{ + Name: key, + Key: key, + Description: "test config", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: optType, + DefaultValue: defaultValue, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestGet(t *testing.T) { //nolint:paralleltest + // reset + options = make(map[string]*Option) + + err := log.Start() + if err != nil { + t.Fatal(err) + } + + quickRegister(t, "monkey", OptTypeString, "c") + quickRegister(t, "zebras/zebra", OptTypeStringArray, []string{"a", "b"}) + quickRegister(t, "elephant", OptTypeInt, -1) + quickRegister(t, "hot", OptTypeBool, false) + quickRegister(t, "cold", OptTypeBool, true) + + err = parseAndReplaceConfig(` + { + "monkey": "a", + "zebras": { + "zebra": ["black", "white"] + }, + "elephant": 2, + "hot": true, + "cold": false + } + `) + if err != nil { + t.Fatal(err) + } + + err = parseAndReplaceDefaultConfig(` + { + "monkey": "b", + "snake": "0", + "elephant": 0 + } + `) + if err != nil { + t.Fatal(err) + } + + monkey := GetAsString("monkey", "none") + if monkey() != "a" { + t.Errorf("monkey should be a, is %s", monkey()) + } + + zebra := GetAsStringArray("zebras/zebra", []string{}) + if len(zebra()) != 2 || zebra()[0] != "black" || zebra()[1] != "white" { + t.Errorf("zebra should be [\"black\", \"white\"], is %v", zebra()) + } + + elephant := GetAsInt("elephant", -1) + if elephant() != 2 { + t.Errorf("elephant should be 2, is %d", elephant()) + } + + hot := GetAsBool("hot", false) + if !hot() { + t.Errorf("hot should be true, is %v", hot()) + } + + cold := GetAsBool("cold", true) + if cold() { + t.Errorf("cold should be false, is %v", cold()) + } + + err = parseAndReplaceConfig(` + { + "monkey": "3" + } + `) + if err != nil { + t.Fatal(err) + } + + if monkey() != "3" { + t.Errorf("monkey should be 0, is %s", monkey()) + } + + if elephant() != 0 { + t.Errorf("elephant should be 0, is %d", elephant()) + } + + zebra() + hot() + + // concurrent + GetAsString("monkey", "none")() + GetAsStringArray("zebras/zebra", []string{})() + GetAsInt("elephant", -1)() + GetAsBool("hot", false)() + + // perspective + + // load data + pLoaded := 
make(map[string]interface{}) + err = json.Unmarshal([]byte(`{ + "monkey": "a", + "zebras": { + "zebra": ["black", "white"] + }, + "elephant": 2, + "hot": true, + "cold": false + }`), &pLoaded) + if err != nil { + t.Fatal(err) + } + + // create + p, err := NewPerspective(pLoaded) + if err != nil { + t.Fatal(err) + } + + monkeyVal, ok := p.GetAsString("monkey") + if !ok || monkeyVal != "a" { + t.Errorf("[perspective] monkey should be a, is %+v", monkeyVal) + } + + zebraVal, ok := p.GetAsStringArray("zebras/zebra") + if !ok || len(zebraVal) != 2 || zebraVal[0] != "black" || zebraVal[1] != "white" { + t.Errorf("[perspective] zebra should be [\"black\", \"white\"], is %+v", zebraVal) + } + + elephantVal, ok := p.GetAsInt("elephant") + if !ok || elephantVal != 2 { + t.Errorf("[perspective] elephant should be 2, is %+v", elephantVal) + } + + hotVal, ok := p.GetAsBool("hot") + if !ok || !hotVal { + t.Errorf("[perspective] hot should be true, is %+v", hotVal) + } + + coldVal, ok := p.GetAsBool("cold") + if !ok || coldVal { + t.Errorf("[perspective] cold should be false, is %+v", coldVal) + } +} + +func TestReleaseLevel(t *testing.T) { //nolint:paralleltest + // reset + options = make(map[string]*Option) + registerReleaseLevelOption() + + // setup + subsystemOption := &Option{ + Name: "test subsystem", + Key: "subsystem/test", + Description: "test config", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: OptTypeBool, + DefaultValue: false, + } + err := Register(subsystemOption) + if err != nil { + t.Fatal(err) + } + err = SetConfigOption("subsystem/test", true) + if err != nil { + t.Fatal(err) + } + testSubsystem := GetAsBool("subsystem/test", false) + + // test option level stable + subsystemOption.ReleaseLevel = ReleaseLevelStable + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable) + if err != nil { + t.Fatal(err) + } + if !testSubsystem() { + t.Error("should be active") + } + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta) + if err != nil { + t.Fatal(err) + } + if !testSubsystem() { + t.Error("should be active") + } + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental) + if err != nil { + t.Fatal(err) + } + if !testSubsystem() { + t.Error("should be active") + } + + // test option level beta + subsystemOption.ReleaseLevel = ReleaseLevelBeta + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable) + if err != nil { + t.Fatal(err) + } + if testSubsystem() { + t.Errorf("should be inactive: opt=%d system=%d", subsystemOption.ReleaseLevel, getReleaseLevel()) + } + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta) + if err != nil { + t.Fatal(err) + } + if !testSubsystem() { + t.Error("should be active") + } + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental) + if err != nil { + t.Fatal(err) + } + if !testSubsystem() { + t.Error("should be active") + } + + // test option level experimental + subsystemOption.ReleaseLevel = ReleaseLevelExperimental + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameStable) + if err != nil { + t.Fatal(err) + } + if testSubsystem() { + t.Error("should be inactive") + } + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameBeta) + if err != nil { + t.Fatal(err) + } + if testSubsystem() { + t.Error("should be inactive") + } + err = SetConfigOption(releaseLevelKey, ReleaseLevelNameExperimental) + if err != nil { + t.Fatal(err) + } + if !testSubsystem() { + t.Error("should be active") + } +} + +func BenchmarkGetAsStringCached(b *testing.B) { + // reset + 
options = make(map[string]*Option) + + // Setup + err := parseAndReplaceConfig(`{ + "monkey": "banana" + }`) + if err != nil { + b.Fatal(err) + } + monkey := GetAsString("monkey", "no banana") + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + monkey() + } +} + +func BenchmarkGetAsStringRefetch(b *testing.B) { + // Setup + err := parseAndReplaceConfig(`{ + "monkey": "banana" + }`) + if err != nil { + b.Fatal(err) + } + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + getValueCache("monkey", nil, OptTypeString) + } +} + +func BenchmarkGetAsIntCached(b *testing.B) { + // Setup + err := parseAndReplaceConfig(`{ + "elephant": 1 + }`) + if err != nil { + b.Fatal(err) + } + elephant := GetAsInt("elephant", -1) + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + elephant() + } +} + +func BenchmarkGetAsIntRefetch(b *testing.B) { + // Setup + err := parseAndReplaceConfig(`{ + "elephant": 1 + }`) + if err != nil { + b.Fatal(err) + } + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + getValueCache("elephant", nil, OptTypeInt) + } +} diff --git a/base/config/main.go b/base/config/main.go new file mode 100644 index 000000000..324c01f92 --- /dev/null +++ b/base/config/main.go @@ -0,0 +1,141 @@ +package config + +import ( + "encoding/json" + "errors" + "flag" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/utils" + "github.com/safing/portmaster/base/utils/debug" +) + +// ChangeEvent is the name of the config change event. +const ChangeEvent = "config change" + +var ( + module *modules.Module + dataRoot *utils.DirStructure + + exportConfig bool +) + +// SetDataRoot sets the data root from which the updates module derives its paths. +func SetDataRoot(root *utils.DirStructure) { + if dataRoot == nil { + dataRoot = root + } +} + +func init() { + module = modules.Register("config", prep, start, nil, "database") + module.RegisterEvent(ChangeEvent, true) + + flag.BoolVar(&exportConfig, "export-config-options", false, "export configuration registry and exit") +} + +func prep() error { + SetDataRoot(dataroot.Root()) + if dataRoot == nil { + return errors.New("data root is not set") + } + + if exportConfig { + modules.SetCmdLineOperation(exportConfigCmd) + } + + return registerBasicOptions() +} + +func start() error { + configFilePath = filepath.Join(dataRoot.Path, "config.json") + + // Load log level from log package after it started. + err := loadLogLevel() + if err != nil { + return err + } + + err = registerAsDatabase() + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return err + } + + err = loadConfig(false) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("failed to load config file: %w", err) + } + return nil +} + +func exportConfigCmd() error { + // Reset the metrics instance name option, as the default + // is set to the current hostname. + // Config key copied from metrics.CfgOptionInstanceKey. 
+ option, err := GetOption("core/metrics/instance") + if err == nil { + option.DefaultValue = "" + } + + data, err := json.MarshalIndent(ExportOptions(), "", " ") + if err != nil { + return err + } + + _, err = os.Stdout.Write(data) + return err +} + +// AddToDebugInfo adds all changed global config options to the given debug.Info. +func AddToDebugInfo(di *debug.Info) { + var lines []string + + // Collect all changed settings. + _ = ForEachOption(func(opt *Option) error { + opt.Lock() + defer opt.Unlock() + + if opt.ReleaseLevel <= getReleaseLevel() && opt.activeValue != nil { + if opt.Sensitive { + lines = append(lines, fmt.Sprintf("%s: [redacted]", opt.Key)) + } else { + lines = append(lines, fmt.Sprintf("%s: %v", opt.Key, opt.activeValue.getData(opt))) + } + } + + return nil + }) + sort.Strings(lines) + + // Add data as section. + di.AddSection( + fmt.Sprintf("Config: %d", len(lines)), + debug.UseCodeSection|debug.AddContentLineBreaks, + lines..., + ) +} + +// GetActiveConfigValues returns a map with the active config values. +func GetActiveConfigValues() map[string]interface{} { + values := make(map[string]interface{}) + + // Collect active values from options. + _ = ForEachOption(func(opt *Option) error { + opt.Lock() + defer opt.Unlock() + + if opt.ReleaseLevel <= getReleaseLevel() && opt.activeValue != nil { + values[opt.Key] = opt.activeValue.getData(opt) + } + + return nil + }) + + return values +} diff --git a/base/config/option.go b/base/config/option.go new file mode 100644 index 000000000..22b1c2022 --- /dev/null +++ b/base/config/option.go @@ -0,0 +1,418 @@ +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "regexp" + "sync" + + "github.com/mitchellh/copystructure" + "github.com/tidwall/sjson" + + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" +) + +// OptionType defines the value type of an option. +type OptionType uint8 + +// Various attribute options. Use ExternalOptType for extended types in the frontend. +const ( + optTypeAny OptionType = 0 + OptTypeString OptionType = 1 + OptTypeStringArray OptionType = 2 + OptTypeInt OptionType = 3 + OptTypeBool OptionType = 4 +) + +func getTypeName(t OptionType) string { + switch t { + case optTypeAny: + return "any" + case OptTypeString: + return "string" + case OptTypeStringArray: + return "[]string" + case OptTypeInt: + return "int" + case OptTypeBool: + return "bool" + default: + return "unknown" + } +} + +// PossibleValue defines a value that is possible for +// a configuration setting. +type PossibleValue struct { + // Name is a human readable name of the option. + Name string + // Description is a human readable description of + // this value. + Description string + // Value is the actual value of the option. The type + // must match the option's value type. + Value interface{} +} + +// Annotations can be attached to configuration options to +// provide hints for user interfaces or other systems working +// or setting configuration options. +// Annotation keys should follow the below format to ensure +// future well-known annotation additions do not conflict +// with vendor/product/package specific annoations. +// +// Format: :: //. +type Annotations map[string]interface{} + +// MigrationFunc is a function that migrates a config option value. +type MigrationFunc func(option *Option, value any) any + +// Well known annotations defined by this package. +const ( + // DisplayHintAnnotation provides a hint for the user + // interface on how to render an option. 
+ // The value of DisplayHintAnnotation is expected to + // be a string. See DisplayHintXXXX constants below + // for a list of well-known display hint annotations. + DisplayHintAnnotation = "safing/portbase:ui:display-hint" + // DisplayOrderAnnotation provides a hint for the user + // interface in which order settings should be displayed. + // The value of DisplayOrderAnnotations is expected to be + // an number (int). + DisplayOrderAnnotation = "safing/portbase:ui:order" + // UnitAnnotations defines the SI unit of an option (if any). + UnitAnnotation = "safing/portbase:ui:unit" + // CategoryAnnotations can provide an additional category + // to each settings. This category can be used by a user + // interface to group certain options together. + // User interfaces should treat a CategoryAnnotation, if + // supported, with higher priority as a DisplayOrderAnnotation. + CategoryAnnotation = "safing/portbase:ui:category" + // SubsystemAnnotation can be used to mark an option as part + // of a module subsystem. + SubsystemAnnotation = "safing/portbase:module:subsystem" + // StackableAnnotation can be set on configuration options that + // stack on top of the default (or otherwise related) options. + // The value of StackableAnnotaiton is expected to be a boolean but + // may be extended to hold references to other options in the + // future. + StackableAnnotation = "safing/portbase:options:stackable" + // RestartPendingAnnotation is automatically set on a configuration option + // that requires a restart and has been changed. + // The value must always be a boolean with value "true". + RestartPendingAnnotation = "safing/portbase:options:restart-pending" + // QuickSettingAnnotation can be used to add quick settings to + // a configuration option. A quick setting can support the user + // by switching between pre-configured values. + // The type of a quick-setting annotation is []QuickSetting or QuickSetting. + QuickSettingsAnnotation = "safing/portbase:ui:quick-setting" + // RequiresAnnotation can be used to mark another option as a + // requirement. The type of RequiresAnnotation is []ValueRequirement + // or ValueRequirement. + RequiresAnnotation = "safing/portbase:config:requires" + // RequiresFeatureIDAnnotation can be used to mark a setting as only available + // when the user has a certain feature ID in the subscription plan. + // The type is []string or string. + RequiresFeatureIDAnnotation = "safing/portmaster:ui:config:requires-feature" + // SettablePerAppAnnotation can be used to mark a setting as settable per-app and + // is a boolean. + SettablePerAppAnnotation = "safing/portmaster:settable-per-app" + // RequiresUIReloadAnnotation can be used to inform the UI that changing the value + // of the annotated setting requires a full reload of the user interface. + // The value of this annotation does not matter as the sole presence of + // the annotation key is enough. Though, users are advised to set the value + // of this annotation to true. + RequiresUIReloadAnnotation = "safing/portmaster:ui:requires-reload" +) + +// QuickSettingsAction defines the action of a quick setting. +type QuickSettingsAction string + +const ( + // QuickReplace replaces the current setting with the one from + // the quick setting. + QuickReplace = QuickSettingsAction("replace") + // QuickMergeTop merges the value of the quick setting with the + // already configured one adding new values on the top. Merging + // is only supported for OptTypeStringArray. 
+ QuickMergeTop = QuickSettingsAction("merge-top") + // QuickMergeBottom merges the value of the quick setting with the + // already configured one adding new values at the bottom. Merging + // is only supported for OptTypeStringArray. + QuickMergeBottom = QuickSettingsAction("merge-bottom") +) + +// QuickSetting defines a quick setting for a configuration option and +// should be used together with the QuickSettingsAnnotation. +type QuickSetting struct { + // Name is the name of the quick setting. + Name string + + // Value is the value that the quick-setting configures. It must match + // the expected value type of the annotated option. + Value interface{} + + // Action defines the action of the quick setting. + Action QuickSettingsAction +} + +// ValueRequirement defines a requirement on another configuration option. +type ValueRequirement struct { + // Key is the key of the configuration option that is required. + Key string + + // Value that is required. + Value interface{} +} + +// Values for the DisplayHintAnnotation. +const ( + // DisplayHintOneOf is used to mark an option + // as a "select"-style option. That is, only one of + // the supported values may be set. This option makes + // only sense together with the PossibleValues property + // of Option. + DisplayHintOneOf = "one-of" + // DisplayHintOrdered is used to mark a list option as ordered. + // That is, the order of items is important and a user interface + // is encouraged to provide the user with re-ordering support + // (like drag'n'drop). + DisplayHintOrdered = "ordered" + // DisplayHintFilePicker is used to mark the option as being a file, which + // should give the option to use a file picker to select a local file from disk. + DisplayHintFilePicker = "file-picker" +) + +// Option describes a configuration option. +type Option struct { + sync.Mutex + // Name holds the name of the configuration options. + // It should be human readable and is mainly used for + // presentation purposes. + // Name is considered immutable after the option has + // been created. + Name string + // Key holds the database path for the option. It should + // follow the path format `category/sub/key`. + // Key is considered immutable after the option has + // been created. + Key string + // Description holds a human readable description of the + // option and what is does. The description should be short. + // Use the Help property for a longer support text. + // Description is considered immutable after the option has + // been created. + Description string + // Help may hold a long version of the description providing + // assistance with the configuration option. + // Help is considered immutable after the option has + // been created. + Help string + // Sensitive signifies that the configuration values may contain sensitive + // content, such as authentication keys. + Sensitive bool + // OptType defines the type of the option. + // OptType is considered immutable after the option has + // been created. + OptType OptionType + // ExpertiseLevel can be used to set the required expertise + // level for the option to be displayed to a user. + // ExpertiseLevel is considered immutable after the option has + // been created. + ExpertiseLevel ExpertiseLevel + // ReleaseLevel is used to mark the stability of the option. + // ReleaseLevel is considered immutable after the option has + // been created. 
+ ReleaseLevel ReleaseLevel + // RequiresRestart should be set to true if a modification of + // the options value requires a restart of the whole application + // to take effect. + // RequiresRestart is considered immutable after the option has + // been created. + RequiresRestart bool + // DefaultValue holds the default value of the option. Note that + // this value can be overwritten during runtime (see activeDefaultValue + // and activeFallbackValue). + // DefaultValue is considered immutable after the option has + // been created. + DefaultValue interface{} + // ValidationRegex may contain a regular expression used to validate + // the value of option. If the option type is set to OptTypeStringArray + // the validation regex is applied to all entries of the string slice. + // Note that it is recommended to keep the validation regex simple so + // it can also be used in other languages (mainly JavaScript) to provide + // a better user-experience by pre-validating the expression. + // ValidationRegex is considered immutable after the option has + // been created. + ValidationRegex string + // ValidationFunc may contain a function to validate more complex values. + // The error is returned beyond the scope of this package and may be + // displayed to a user. + ValidationFunc func(value interface{}) error `json:"-"` + // PossibleValues may be set to a slice of values that are allowed + // for this configuration setting. Note that PossibleValues makes most + // sense when ExternalOptType is set to HintOneOf + // PossibleValues is considered immutable after the option has + // been created. + PossibleValues []PossibleValue `json:",omitempty"` + // Annotations adds additional annotations to the configuration options. + // See documentation of Annotations for more information. + // Annotations is considered mutable and setting/reading annotation keys + // must be performed while the option is locked. + Annotations Annotations + // Migrations holds migration functions that are given the raw option value + // before any validation is run. The returned value is then used. + Migrations []MigrationFunc `json:"-"` + + activeValue *valueCache // runtime value (loaded from config file or set by user) + activeDefaultValue *valueCache // runtime default value (may be set internally) + activeFallbackValue *valueCache // default value from option registration + compiledRegex *regexp.Regexp +} + +// AddAnnotation adds the annotation key to option if it's not already set. +func (option *Option) AddAnnotation(key string, value interface{}) { + option.Lock() + defer option.Unlock() + + if option.Annotations == nil { + option.Annotations = make(Annotations) + } + + if _, ok := option.Annotations[key]; ok { + return + } + option.Annotations[key] = value +} + +// SetAnnotation sets the value of the annotation key overwritting an +// existing value if required. +func (option *Option) SetAnnotation(key string, value interface{}) { + option.Lock() + defer option.Unlock() + + option.setAnnotation(key, value) +} + +// setAnnotation sets the value of the annotation key overwritting an +// existing value if required. Does not lock the Option. +func (option *Option) setAnnotation(key string, value interface{}) { + if option.Annotations == nil { + option.Annotations = make(Annotations) + } + option.Annotations[key] = value +} + +// GetAnnotation returns the value of the annotation key. 
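The `Option` fields documented above come together at registration time. The sketch below shows a fairly complete registration with annotations, possible values, and a custom validation function; the key and values are hypothetical, and the validation regex is left empty so that `Register` (further down in `registry.go`) derives it from `PossibleValues`:

```go
package main

import (
	"errors"

	"github.com/safing/portmaster/base/config"
)

func main() {
	err := config.Register(&config.Option{
		Name:           "Example Log Level",
		Key:            "example/logLevel", // hypothetical key
		Description:    "Verbosity of the hypothetical example subsystem.",
		OptType:        config.OptTypeString,
		ExpertiseLevel: config.ExpertiseLevelExpert,
		ReleaseLevel:   config.ReleaseLevelStable,
		DefaultValue:   "info",
		Annotations: config.Annotations{
			config.DisplayHintAnnotation:  config.DisplayHintOneOf,
			config.DisplayOrderAnnotation: 10,
			config.CategoryAnnotation:     "Example",
		},
		PossibleValues: []config.PossibleValue{
			{Name: "Info", Value: "info", Description: "Default verbosity."},
			{Name: "Debug", Value: "debug", Description: "Very verbose."},
		},
		// ValidationRegex is left empty, so Register derives "^(info|debug)$"
		// from PossibleValues.
		ValidationFunc: func(value interface{}) error {
			if value == "trace" {
				return errors.New("trace level is not supported here")
			}
			return nil
		},
	})
	if err != nil {
		panic(err)
	}
}
```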
+func (option *Option) GetAnnotation(key string) (interface{}, bool) { + option.Lock() + defer option.Unlock() + + if option.Annotations == nil { + return nil, false + } + val, ok := option.Annotations[key] + return val, ok +} + +// AnnotationEquals returns whether the annotation of the given key matches the +// given value. +func (option *Option) AnnotationEquals(key string, value any) bool { + option.Lock() + defer option.Unlock() + + if option.Annotations == nil { + return false + } + setValue, ok := option.Annotations[key] + if !ok { + return false + } + return reflect.DeepEqual(value, setValue) +} + +// copyOrNil returns a copy of the option, or nil if copying failed. +func (option *Option) copyOrNil() *Option { + copied, err := copystructure.Copy(option) + if err != nil { + return nil + } + return copied.(*Option) //nolint:forcetypeassert +} + +// IsSetByUser returns whether the option has been set by the user. +func (option *Option) IsSetByUser() bool { + option.Lock() + defer option.Unlock() + + return option.activeValue != nil +} + +// UserValue returns the value set by the user or nil if the value has not +// been changed from the default. +func (option *Option) UserValue() any { + option.Lock() + defer option.Unlock() + + if option.activeValue == nil { + return nil + } + return option.activeValue.getData(option) +} + +// ValidateValue checks if the given value is valid for the option. +func (option *Option) ValidateValue(value any) error { + option.Lock() + defer option.Unlock() + + value = migrateValue(option, value) + if _, err := validateValue(option, value); err != nil { + return err + } + return nil +} + +// Export expors an option to a Record. +func (option *Option) Export() (record.Record, error) { + option.Lock() + defer option.Unlock() + + return option.export() +} + +func (option *Option) export() (record.Record, error) { + data, err := json.Marshal(option) + if err != nil { + return nil, err + } + + if option.activeValue != nil { + data, err = sjson.SetBytes(data, "Value", option.activeValue.getData(option)) + if err != nil { + return nil, err + } + } + + if option.activeDefaultValue != nil { + data, err = sjson.SetBytes(data, "DefaultValue", option.activeDefaultValue.getData(option)) + if err != nil { + return nil, err + } + } + + r, err := record.NewWrapper(fmt.Sprintf("config:%s", option.Key), nil, dsd.JSON, data) + if err != nil { + return nil, err + } + r.SetMeta(&record.Meta{}) + + return r, nil +} + +type sortByKey []*Option + +func (opts sortByKey) Len() int { return len(opts) } +func (opts sortByKey) Less(i, j int) bool { return opts[i].Key < opts[j].Key } +func (opts sortByKey) Swap(i, j int) { opts[i], opts[j] = opts[j], opts[i] } diff --git a/base/config/persistence.go b/base/config/persistence.go new file mode 100644 index 000000000..87fd9e91c --- /dev/null +++ b/base/config/persistence.go @@ -0,0 +1,234 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" + "path" + "strings" + "sync" + + "github.com/safing/portmaster/base/log" +) + +var ( + configFilePath string + + loadedConfigValidationErrors []*ValidationError + loadedConfigValidationErrorsLock sync.Mutex +) + +// GetLoadedConfigValidationErrors returns the encountered validation errors +// from the last time loading config from disk. 
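The `Migrations` field documented above feeds into `migrateValue` (defined in `validate.go` further down), which runs every migration on the raw value before validation and logs when a value was rewritten. A sketch of a rename-style migration for a hypothetical option, assuming the config module has been prepared:

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/config"
)

func main() {
	// Rename the deprecated value "standard" to "stable" before validation.
	// Key and values are hypothetical.
	renameStandard := func(option *config.Option, value any) any {
		if s, ok := value.(string); ok && s == "standard" {
			return "stable"
		}
		return value
	}

	_ = config.Register(&config.Option{
		Name:         "Example Channel",
		Key:          "example/channel",
		Description:  "Example release channel.",
		OptType:      config.OptTypeString,
		DefaultValue: "stable",
		Migrations:   []config.MigrationFunc{renameStandard},
	})

	// A value carried over from an old config file is rewritten before it is
	// validated and stored.
	_ = config.SetConfigOption("example/channel", "standard")

	channel := config.GetAsString("example/channel", "stable")
	fmt.Println(channel()) // "stable"
}
```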
+func GetLoadedConfigValidationErrors() []*ValidationError { + loadedConfigValidationErrorsLock.Lock() + defer loadedConfigValidationErrorsLock.Unlock() + + return loadedConfigValidationErrors +} + +func loadConfig(requireValidConfig bool) error { + // check if persistence is configured + if configFilePath == "" { + return nil + } + + // read config file + data, err := os.ReadFile(configFilePath) + if err != nil { + return err + } + + // convert to map + newValues, err := JSONToMap(data) + if err != nil { + return err + } + + validationErrors, _ := ReplaceConfig(newValues) + if requireValidConfig && len(validationErrors) > 0 { + return fmt.Errorf("encountered %d validation errors during config loading", len(validationErrors)) + } + + // Save validation errors. + loadedConfigValidationErrorsLock.Lock() + defer loadedConfigValidationErrorsLock.Unlock() + loadedConfigValidationErrors = validationErrors + + return nil +} + +// SaveConfig saves the current configuration to file. +// It will acquire a read-lock on the global options registry +// lock and must lock each option! +func SaveConfig() error { + optionsLock.RLock() + defer optionsLock.RUnlock() + + // check if persistence is configured + if configFilePath == "" { + return nil + } + + // extract values + activeValues := make(map[string]interface{}) + for key, option := range options { + // we cannot immedately unlock the option afger + // getData() because someone could lock and change it + // while we are marshaling the value (i.e. for string slices). + // We NEED to keep the option locks until we finsihed. + option.Lock() + defer option.Unlock() + + if option.activeValue != nil { + activeValues[key] = option.activeValue.getData(option) + } + } + + // convert to JSON + data, err := MapToJSON(activeValues) + if err != nil { + log.Errorf("config: failed to save config: %s", err) + return err + } + + // write file + return os.WriteFile(configFilePath, data, 0o0600) +} + +// JSONToMap parses and flattens a hierarchical json object. +func JSONToMap(jsonData []byte) (map[string]interface{}, error) { + loaded := make(map[string]interface{}) + err := json.Unmarshal(jsonData, &loaded) + if err != nil { + return nil, err + } + + return Flatten(loaded), nil +} + +// Flatten returns a flattened copy of the given hierarchical config. +func Flatten(config map[string]interface{}) (flattenedConfig map[string]interface{}) { + flattenedConfig = make(map[string]interface{}) + flattenMap(flattenedConfig, config, "") + return flattenedConfig +} + +func flattenMap(rootMap, subMap map[string]interface{}, subKey string) { + for key, entry := range subMap { + + // get next level key + subbedKey := path.Join(subKey, key) + + // check for next subMap + nextSub, ok := entry.(map[string]interface{}) + if ok { + flattenMap(rootMap, nextSub, subbedKey) + } else { + // only set if not on root level + rootMap[subbedKey] = entry + } + } +} + +// MapToJSON expands a flattened map and returns it as json. +func MapToJSON(config map[string]interface{}) ([]byte, error) { + return json.MarshalIndent(Expand(config), "", " ") +} + +// Expand returns a hierarchical copy of the given flattened config. +func Expand(flattenedConfig map[string]interface{}) (config map[string]interface{}) { + config = make(map[string]interface{}) + for key, entry := range flattenedConfig { + PutValueIntoHierarchicalConfig(config, key, entry) + } + return config +} + +// PutValueIntoHierarchicalConfig injects a configuration entry into an hierarchical config map. Conflicting entries will be replaced. 
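The persistence helpers above convert between the flat `category/sub/key` form used internally and the nested form stored on disk: `JSONToMap` and `Flatten` flatten a hierarchy, while `Expand` and `MapToJSON` rebuild it. A quick round trip with made-up keys:

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/config"
)

func main() {
	flat, err := config.JSONToMap([]byte(`{"filter": {"enabled": true, "lists": ["a", "b"]}}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(flat) // e.g. map[filter/enabled:true filter/lists:[a b]]

	// Expand restores the hierarchy; MapToJSON does the same and marshals it.
	j, err := config.MapToJSON(flat)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(j)) // nested, indented JSON again
}
```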
+func PutValueIntoHierarchicalConfig(config map[string]interface{}, key string, value interface{}) { + parts := strings.Split(key, "/") + + // create/check maps for all parts except the last one + subMap := config + for i, part := range parts { + if i == len(parts)-1 { + // do not process the last part, + // which is not a map, but the value key itself + break + } + + var nextSubMap map[string]interface{} + // get value + value, ok := subMap[part] + if !ok { + // create new map and assign it + nextSubMap = make(map[string]interface{}) + subMap[part] = nextSubMap + } else { + nextSubMap, ok = value.(map[string]interface{}) + if !ok { + // create new map and assign it + nextSubMap = make(map[string]interface{}) + subMap[part] = nextSubMap + } + } + + // assign for next parts loop + subMap = nextSubMap + } + + // assign value to last submap + subMap[parts[len(parts)-1]] = value +} + +// CleanFlattenedConfig removes all inexistent configuration options from the given flattened config map. +func CleanFlattenedConfig(flattenedConfig map[string]interface{}) { + optionsLock.RLock() + defer optionsLock.RUnlock() + + for key := range flattenedConfig { + _, ok := options[key] + if !ok { + delete(flattenedConfig, key) + } + } +} + +// CleanHierarchicalConfig removes all inexistent configuration options from the given hierarchical config map. +func CleanHierarchicalConfig(config map[string]interface{}) { + optionsLock.RLock() + defer optionsLock.RUnlock() + + cleanSubMap(config, "") +} + +func cleanSubMap(subMap map[string]interface{}, subKey string) (empty bool) { + var foundValid int + for key, value := range subMap { + value, ok := value.(map[string]interface{}) + if ok { + // we found another section + isEmpty := cleanSubMap(value, path.Join(subKey, key)) + if isEmpty { + delete(subMap, key) + } else { + foundValid++ + } + continue + } + + // we found an option value + if strings.Contains(key, "/") { + delete(subMap, key) + } else { + _, ok := options[path.Join(subKey, key)] + if ok { + foundValid++ + } else { + delete(subMap, key) + } + } + } + return foundValid == 0 +} diff --git a/base/config/persistence_test.go b/base/config/persistence_test.go new file mode 100644 index 000000000..982835c3f --- /dev/null +++ b/base/config/persistence_test.go @@ -0,0 +1,97 @@ +package config + +import ( + "bytes" + "encoding/json" + "testing" +) + +var ( + jsonData = `{ + "a": "b", + "c": { + "d": "e", + "f": "g", + "h": { + "i": "j", + "k": "l", + "m": { + "n": "o" + } + } + }, + "p": "q" +}` + jsonBytes = []byte(jsonData) + + mapData = map[string]interface{}{ + "a": "b", + "p": "q", + "c/d": "e", + "c/f": "g", + "c/h/i": "j", + "c/h/k": "l", + "c/h/m/n": "o", + } +) + +func TestJSONMapConversion(t *testing.T) { + t.Parallel() + + // convert to json + j, err := MapToJSON(mapData) + if err != nil { + t.Fatal(err) + } + + // check if to json matches + if !bytes.Equal(jsonBytes, j) { + t.Errorf("json does not match, got %s", j) + } + + // convert to map + m, err := JSONToMap(jsonBytes) + if err != nil { + t.Fatal(err) + } + + // and back + j2, err := MapToJSON(m) + if err != nil { + t.Fatal(err) + } + + // check if double convert matches + if !bytes.Equal(jsonBytes, j2) { + t.Errorf("json does not match, got %s", j) + } +} + +func TestConfigCleaning(t *testing.T) { + t.Parallel() + + // load + configFlat, err := JSONToMap(jsonBytes) + if err != nil { + t.Fatal(err) + } + + // clean everything + CleanFlattenedConfig(configFlat) + if len(configFlat) != 0 { + t.Errorf("should be empty: %+v", configFlat) + } + + // 
load manuall for hierarchical config + configHier := make(map[string]interface{}) + err = json.Unmarshal(jsonBytes, &configHier) + if err != nil { + t.Fatal(err) + } + + // clean everything + CleanHierarchicalConfig(configHier) + if len(configHier) != 0 { + t.Errorf("should be empty: %+v", configHier) + } +} diff --git a/base/config/perspective.go b/base/config/perspective.go new file mode 100644 index 000000000..7fd6a62cc --- /dev/null +++ b/base/config/perspective.go @@ -0,0 +1,133 @@ +package config + +import ( + "fmt" + + "github.com/safing/portmaster/base/log" +) + +// Perspective is a view on configuration data without interfering with the configuration system. +type Perspective struct { + config map[string]*perspectiveOption +} + +type perspectiveOption struct { + option *Option + valueCache *valueCache +} + +// NewPerspective parses the given config and returns it as a new perspective. +func NewPerspective(config map[string]interface{}) (*Perspective, error) { + // flatten config structure + config = Flatten(config) + + perspective := &Perspective{ + config: make(map[string]*perspectiveOption), + } + var firstErr error + var errCnt int + + optionsLock.RLock() +optionsLoop: + for key, option := range options { + // get option key from config + configValue, ok := config[key] + if !ok { + continue + } + // migrate value + configValue = migrateValue(option, configValue) + // validate value + valueCache, err := validateValue(option, configValue) + if err != nil { + errCnt++ + if firstErr == nil { + firstErr = err + } + continue optionsLoop + } + + // add to perspective + perspective.config[key] = &perspectiveOption{ + option: option, + valueCache: valueCache, + } + } + optionsLock.RUnlock() + + if firstErr != nil { + if errCnt > 0 { + return perspective, fmt.Errorf("encountered %d errors, first was: %w", errCnt, firstErr) + } + return perspective, firstErr + } + + return perspective, nil +} + +func (p *Perspective) getPerspectiveValueCache(name string, requestedType OptionType) *valueCache { + // get option + pOption, ok := p.config[name] + if !ok { + // check if option exists at all + if _, err := GetOption(name); err != nil { + log.Errorf("config: request for unregistered option: %s", name) + } + return nil + } + + // check type + if requestedType != pOption.option.OptType && requestedType != optTypeAny { + log.Errorf("config: bad type: requested %s as %s, but is %s", name, getTypeName(requestedType), getTypeName(pOption.option.OptType)) + return nil + } + + // check release level + if pOption.option.ReleaseLevel > getReleaseLevel() { + return nil + } + + return pOption.valueCache +} + +// Has returns whether the given option is set in the perspective. +func (p *Perspective) Has(name string) bool { + valueCache := p.getPerspectiveValueCache(name, optTypeAny) + return valueCache != nil +} + +// GetAsString returns a function that returns the wanted string with high performance. +func (p *Perspective) GetAsString(name string) (value string, ok bool) { + valueCache := p.getPerspectiveValueCache(name, OptTypeString) + if valueCache != nil { + return valueCache.stringVal, true + } + return "", false +} + +// GetAsStringArray returns a function that returns the wanted string with high performance. 
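A `Perspective` only picks up keys that are registered as options, and values that fail validation are skipped and reported through the returned error. The sketch below registers two hypothetical options first and then builds a perspective on external data without touching the active configuration:

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/config"
)

func main() {
	// Hypothetical options; only registered keys end up in a perspective.
	_ = config.Register(&config.Option{
		Name: "Example Name", Key: "example/name",
		Description: "example", OptType: config.OptTypeString, DefaultValue: "none",
	})
	_ = config.Register(&config.Option{
		Name: "Example Enabled", Key: "example/enabled",
		Description: "example", OptType: config.OptTypeBool, DefaultValue: false,
	})

	p, err := config.NewPerspective(map[string]interface{}{
		"example": map[string]interface{}{
			"name":    "demo",
			"enabled": true,
		},
	})
	if err != nil {
		fmt.Println("some values were dropped:", err)
	}

	if name, ok := p.GetAsString("example/name"); ok {
		fmt.Println("name:", name) // "demo"
	}
	if enabled, ok := p.GetAsBool("example/enabled"); ok {
		fmt.Println("enabled:", enabled) // true
	}
}
```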
+func (p *Perspective) GetAsStringArray(name string) (value []string, ok bool) { + valueCache := p.getPerspectiveValueCache(name, OptTypeStringArray) + if valueCache != nil { + return valueCache.stringArrayVal, true + } + return nil, false +} + +// GetAsInt returns a function that returns the wanted int with high performance. +func (p *Perspective) GetAsInt(name string) (value int64, ok bool) { + valueCache := p.getPerspectiveValueCache(name, OptTypeInt) + if valueCache != nil { + return valueCache.intVal, true + } + return 0, false +} + +// GetAsBool returns a function that returns the wanted int with high performance. +func (p *Perspective) GetAsBool(name string) (value bool, ok bool) { + valueCache := p.getPerspectiveValueCache(name, OptTypeBool) + if valueCache != nil { + return valueCache.boolVal, true + } + return false, false +} diff --git a/base/config/registry.go b/base/config/registry.go new file mode 100644 index 000000000..04b0dbf63 --- /dev/null +++ b/base/config/registry.go @@ -0,0 +1,106 @@ +package config + +import ( + "fmt" + "regexp" + "sort" + "strings" + "sync" +) + +var ( + optionsLock sync.RWMutex + options = make(map[string]*Option) +) + +// ForEachOption calls fn for each defined option. If fn returns +// and error the iteration is stopped and the error is returned. +// Note that ForEachOption does not guarantee a stable order of +// iteration between multiple calles. ForEachOption does NOT lock +// opt when calling fn. +func ForEachOption(fn func(opt *Option) error) error { + optionsLock.RLock() + defer optionsLock.RUnlock() + + for _, opt := range options { + if err := fn(opt); err != nil { + return err + } + } + return nil +} + +// ExportOptions exports the registered options. The returned data must be +// treated as immutable. +// The data does not include the current active or default settings. +func ExportOptions() []*Option { + optionsLock.RLock() + defer optionsLock.RUnlock() + + // Copy the map into a slice. + opts := make([]*Option, 0, len(options)) + for _, opt := range options { + opts = append(opts, opt) + } + + sort.Sort(sortByKey(opts)) + return opts +} + +// GetOption returns the option with name or an error +// if the option does not exist. The caller should lock +// the returned option itself for further processing. +func GetOption(name string) (*Option, error) { + optionsLock.RLock() + defer optionsLock.RUnlock() + + opt, ok := options[name] + if !ok { + return nil, fmt.Errorf("option %q does not exist", name) + } + return opt, nil +} + +// Register registers a new configuration option. 
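`ForEachOption` and `ExportOptions` above are the two ways to walk the registry: the former iterates in place without locking each option, the latter returns a key-sorted snapshot slice. A short sketch that lists registered keys and user-changed settings; note that option methods such as `IsSetByUser` take the option lock themselves:

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/config"
)

func main() {
	// Print all registered keys; ExportOptions returns a key-sorted snapshot.
	for _, opt := range config.ExportOptions() {
		fmt.Println(opt.Key) // Key is immutable, safe to read without locking
	}

	// Report user-changed settings; ForEachOption does not lock the option,
	// but IsSetByUser and UserValue take the option lock themselves.
	_ = config.ForEachOption(func(opt *config.Option) error {
		if opt.IsSetByUser() {
			fmt.Printf("%s = %v\n", opt.Key, opt.UserValue())
		}
		return nil
	})
}
```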
+func Register(option *Option) error { + if option.Name == "" { + return fmt.Errorf("failed to register option: please set option.Name") + } + if option.Key == "" { + return fmt.Errorf("failed to register option: please set option.Key") + } + if option.Description == "" { + return fmt.Errorf("failed to register option: please set option.Description") + } + if option.OptType == 0 { + return fmt.Errorf("failed to register option: please set option.OptType") + } + + if option.ValidationRegex == "" && option.PossibleValues != nil { + values := make([]string, len(option.PossibleValues)) + for idx, val := range option.PossibleValues { + values[idx] = fmt.Sprintf("%v", val.Value) + } + option.ValidationRegex = fmt.Sprintf("^(%s)$", strings.Join(values, "|")) + } + + var err error + if option.ValidationRegex != "" { + option.compiledRegex, err = regexp.Compile(option.ValidationRegex) + if err != nil { + return fmt.Errorf("config: could not compile option.ValidationRegex: %w", err) + } + } + + var vErr *ValidationError + option.activeFallbackValue, vErr = validateValue(option, option.DefaultValue) + if vErr != nil { + return fmt.Errorf("config: invalid default value: %w", vErr) + } + + optionsLock.Lock() + defer optionsLock.Unlock() + options[option.Key] = option + + return nil +} diff --git a/base/config/registry_test.go b/base/config/registry_test.go new file mode 100644 index 000000000..c64aa2db9 --- /dev/null +++ b/base/config/registry_test.go @@ -0,0 +1,49 @@ +package config + +import ( + "testing" +) + +func TestRegistry(t *testing.T) { //nolint:paralleltest + // reset + options = make(map[string]*Option) + + if err := Register(&Option{ + Name: "name", + Key: "key", + Description: "description", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: OptTypeString, + DefaultValue: "water", + ValidationRegex: "^(banana|water)$", + }); err != nil { + t.Error(err) + } + + if err := Register(&Option{ + Name: "name", + Key: "key", + Description: "description", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: 0, + DefaultValue: "default", + ValidationRegex: "^[A-Z][a-z]+$", + }); err == nil { + t.Error("should fail") + } + + if err := Register(&Option{ + Name: "name", + Key: "key", + Description: "description", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: OptTypeString, + DefaultValue: "default", + ValidationRegex: "[", + }); err == nil { + t.Error("should fail") + } +} diff --git a/base/config/release.go b/base/config/release.go new file mode 100644 index 000000000..2b50e1677 --- /dev/null +++ b/base/config/release.go @@ -0,0 +1,101 @@ +package config + +import ( + "sync/atomic" + + "github.com/tevino/abool" +) + +// ReleaseLevel is used to define the maturity of a +// configuration setting. +type ReleaseLevel uint8 + +// Release Level constants. +const ( + ReleaseLevelStable ReleaseLevel = 0 + ReleaseLevelBeta ReleaseLevel = 1 + ReleaseLevelExperimental ReleaseLevel = 2 + + ReleaseLevelNameStable = "stable" + ReleaseLevelNameBeta = "beta" + ReleaseLevelNameExperimental = "experimental" + + releaseLevelKey = "core/releaseLevel" +) + +var ( + releaseLevel = new(int32) + releaseLevelOption *Option + releaseLevelOptionFlag = abool.New() +) + +func init() { + registerReleaseLevelOption() +} + +func registerReleaseLevelOption() { + releaseLevelOption = &Option{ + Name: "Feature Stability", + Key: releaseLevelKey, + Description: `May break things. 
Decide if you want to experiment with unstable features. "Beta" has been tested roughly by the Safing team while "Experimental" is really raw. When "Beta" or "Experimental" are disabled, their settings use the default again.`, + OptType: OptTypeString, + ExpertiseLevel: ExpertiseLevelDeveloper, + ReleaseLevel: ReleaseLevelStable, + DefaultValue: ReleaseLevelNameStable, + Annotations: Annotations{ + DisplayOrderAnnotation: -8, + DisplayHintAnnotation: DisplayHintOneOf, + CategoryAnnotation: "Updates", + }, + PossibleValues: []PossibleValue{ + { + Name: "Stable", + Value: ReleaseLevelNameStable, + Description: "Only show stable features.", + }, + { + Name: "Beta", + Value: ReleaseLevelNameBeta, + Description: "Show stable and beta features.", + }, + { + Name: "Experimental", + Value: ReleaseLevelNameExperimental, + Description: "Show all features", + }, + }, + } + + err := Register(releaseLevelOption) + if err != nil { + panic(err) + } + + releaseLevelOptionFlag.Set() +} + +func updateReleaseLevel() { + // get value + value := releaseLevelOption.activeFallbackValue + if releaseLevelOption.activeValue != nil { + value = releaseLevelOption.activeValue + } + if releaseLevelOption.activeDefaultValue != nil { + value = releaseLevelOption.activeDefaultValue + } + // set atomic value + switch value.stringVal { + case ReleaseLevelNameStable: + atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable)) + case ReleaseLevelNameBeta: + atomic.StoreInt32(releaseLevel, int32(ReleaseLevelBeta)) + case ReleaseLevelNameExperimental: + atomic.StoreInt32(releaseLevel, int32(ReleaseLevelExperimental)) + default: + atomic.StoreInt32(releaseLevel, int32(ReleaseLevelStable)) + } +} + +func getReleaseLevel() ReleaseLevel { + return ReleaseLevel(atomic.LoadInt32(releaseLevel)) +} diff --git a/base/config/set.go b/base/config/set.go new file mode 100644 index 000000000..2c40ca213 --- /dev/null +++ b/base/config/set.go @@ -0,0 +1,235 @@ +package config + +import ( + "errors" + "sync" + + "github.com/tevino/abool" +) + +var ( + // ErrInvalidJSON is returned by SetConfig and SetDefaultConfig if they receive invalid json. + ErrInvalidJSON = errors.New("json string invalid") + + // ErrInvalidOptionType is returned by SetConfigOption and SetDefaultConfigOption if given an unsupported option type. + ErrInvalidOptionType = errors.New("invalid option value type") + + validityFlag = abool.NewBool(true) + validityFlagLock sync.RWMutex +) + +// getValidityFlag returns a flag that signifies if the configuration has been changed. This flag must not be changed, only read. +func getValidityFlag() *abool.AtomicBool { + validityFlagLock.RLock() + defer validityFlagLock.RUnlock() + return validityFlag +} + +// signalChanges marks the configs validtityFlag as dirty and eventually +// triggers a config change event. +func signalChanges() { + // reset validity flag + validityFlagLock.Lock() + validityFlag.SetTo(false) + validityFlag = abool.NewBool(true) + validityFlagLock.Unlock() + + module.TriggerEvent(ChangeEvent, nil) +} + +// ValidateConfig validates the given configuration and returns all validation +// errors as well as whether the given configuration contains unknown keys. +func ValidateConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool, containsUnknown bool) { + // RLock the options because we are not adding or removing + // options from the registration but rather only checking the + // options value which is guarded by the option's lock itself. 
+ optionsLock.RLock() + defer optionsLock.RUnlock() + + var checked int + for key, option := range options { + newValue, ok := newValues[key] + if ok { + checked++ + + func() { + option.Lock() + defer option.Unlock() + + newValue = migrateValue(option, newValue) + _, err := validateValue(option, newValue) + if err != nil { + validationErrors = append(validationErrors, err) + } + + if option.RequiresRestart { + requiresRestart = true + } + }() + } + } + + return validationErrors, requiresRestart, checked < len(newValues) +} + +// ReplaceConfig sets the (prioritized) user defined config. +func ReplaceConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) { + // RLock the options because we are not adding or removing + // options from the registration but rather only update the + // options value which is guarded by the option's lock itself. + optionsLock.RLock() + defer optionsLock.RUnlock() + + for key, option := range options { + newValue, ok := newValues[key] + + func() { + option.Lock() + defer option.Unlock() + + option.activeValue = nil + if ok { + newValue = migrateValue(option, newValue) + valueCache, err := validateValue(option, newValue) + if err == nil { + option.activeValue = valueCache + } else { + validationErrors = append(validationErrors, err) + } + } + handleOptionUpdate(option, true) + + if option.RequiresRestart { + requiresRestart = true + } + }() + } + + signalChanges() + + return validationErrors, requiresRestart +} + +// ReplaceDefaultConfig sets the (fallback) default config. +func ReplaceDefaultConfig(newValues map[string]interface{}) (validationErrors []*ValidationError, requiresRestart bool) { + // RLock the options because we are not adding or removing + // options from the registration but rather only update the + // options value which is guarded by the option's lock itself. + optionsLock.RLock() + defer optionsLock.RUnlock() + + for key, option := range options { + newValue, ok := newValues[key] + + func() { + option.Lock() + defer option.Unlock() + + option.activeDefaultValue = nil + if ok { + newValue = migrateValue(option, newValue) + valueCache, err := validateValue(option, newValue) + if err == nil { + option.activeDefaultValue = valueCache + } else { + validationErrors = append(validationErrors, err) + } + } + handleOptionUpdate(option, true) + + if option.RequiresRestart { + requiresRestart = true + } + }() + } + + signalChanges() + + return validationErrors, requiresRestart +} + +// SetConfigOption sets a single value in the (prioritized) user defined config. +func SetConfigOption(key string, value any) error { + return setConfigOption(key, value, true) +} + +func setConfigOption(key string, value any, push bool) (err error) { + option, err := GetOption(key) + if err != nil { + return err + } + + option.Lock() + if value == nil { + option.activeValue = nil + } else { + value = migrateValue(option, value) + valueCache, vErr := validateValue(option, value) + if vErr == nil { + option.activeValue = valueCache + } else { + err = vErr + } + } + + // Add the "restart pending" annotation if the settings requires a restart. + if option.RequiresRestart { + option.setAnnotation(RestartPendingAnnotation, true) + } + + handleOptionUpdate(option, push) + option.Unlock() + + if err != nil { + return err + } + + // finalize change, activate triggers + signalChanges() + + return SaveConfig() +} + +// SetDefaultConfigOption sets a single value in the (fallback) default config. 
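`ValidateConfig` above checks a candidate map without touching the active configuration, while `ReplaceConfig` applies it wholesale and resets every option that is missing from the map. A sketch of validating before applying, assuming the config module is running; `rawJSON` and the example key are stand-ins:

```go
package main

import (
	"fmt"

	"github.com/safing/portmaster/base/config"
)

// applyCandidate validates a candidate configuration before replacing the
// active one. rawJSON stands in for data from any source.
func applyCandidate(rawJSON []byte) error {
	candidate, err := config.JSONToMap(rawJSON)
	if err != nil {
		return err
	}

	// Dry run: nothing is applied yet.
	vErrs, _, hasUnknown := config.ValidateConfig(candidate)
	if len(vErrs) > 0 {
		return fmt.Errorf("rejected: %d invalid values, first: %w", len(vErrs), vErrs[0])
	}
	if hasUnknown {
		fmt.Println("warning: candidate contains keys that are not registered options")
	}

	// Apply: options missing from the candidate are reset to their defaults.
	vErrs, requiresRestart := config.ReplaceConfig(candidate)
	if len(vErrs) > 0 {
		return fmt.Errorf("partially applied: %d invalid values", len(vErrs))
	}
	if requiresRestart {
		fmt.Println("a changed option requires a restart to take effect")
	}
	return nil
}

func main() {
	_ = applyCandidate([]byte(`{"example": {"enabled": true}}`))
}
```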
+func SetDefaultConfigOption(key string, value interface{}) error { + return setDefaultConfigOption(key, value, true) +} + +func setDefaultConfigOption(key string, value interface{}, push bool) (err error) { + option, err := GetOption(key) + if err != nil { + return err + } + + option.Lock() + if value == nil { + option.activeDefaultValue = nil + } else { + value = migrateValue(option, value) + valueCache, vErr := validateValue(option, value) + if vErr == nil { + option.activeDefaultValue = valueCache + } else { + err = vErr + } + } + + // Add the "restart pending" annotation if the settings requires a restart. + if option.RequiresRestart { + option.setAnnotation(RestartPendingAnnotation, true) + } + + handleOptionUpdate(option, push) + option.Unlock() + + if err != nil { + return err + } + + // finalize change, activate triggers + signalChanges() + + // Do not save the configuration, as it only saves the active values, not the + // active default value. + return nil +} diff --git a/base/config/set_test.go b/base/config/set_test.go new file mode 100644 index 000000000..9f52a04ad --- /dev/null +++ b/base/config/set_test.go @@ -0,0 +1,193 @@ +//nolint:goconst +package config + +import "testing" + +func TestLayersGetters(t *testing.T) { //nolint:paralleltest + // reset + options = make(map[string]*Option) + + mapData, err := JSONToMap([]byte(` + { + "monkey": "1", + "elephant": 2, + "zebras": { + "zebra": ["black", "white"], + "weird_zebra": ["black", -1] + }, + "env": { + "hot": true + } + } + `)) + if err != nil { + t.Fatal(err) + } + + validationErrors, _ := ReplaceConfig(mapData) + if len(validationErrors) > 0 { + t.Fatalf("%d errors, first: %s", len(validationErrors), validationErrors[0].Error()) + } + + // Test missing values + + missingString := GetAsString("missing", "fallback") + if missingString() != "fallback" { + t.Error("expected fallback value: fallback") + } + + missingStringArray := GetAsStringArray("missing", []string{"fallback"}) + if len(missingStringArray()) != 1 || missingStringArray()[0] != "fallback" { + t.Error("expected fallback value: [fallback]") + } + + missingInt := GetAsInt("missing", -1) + if missingInt() != -1 { + t.Error("expected fallback value: -1") + } + + missingBool := GetAsBool("missing", false) + if missingBool() { + t.Error("expected fallback value: false") + } + + // Test value mismatch + + notString := GetAsString("elephant", "fallback") + if notString() != "fallback" { + t.Error("expected fallback value: fallback") + } + + notStringArray := GetAsStringArray("elephant", []string{"fallback"}) + if len(notStringArray()) != 1 || notStringArray()[0] != "fallback" { + t.Error("expected fallback value: [fallback]") + } + + mixedStringArray := GetAsStringArray("zebras/weird_zebra", []string{"fallback"}) + if len(mixedStringArray()) != 1 || mixedStringArray()[0] != "fallback" { + t.Error("expected fallback value: [fallback]") + } + + notInt := GetAsInt("monkey", -1) + if notInt() != -1 { + t.Error("expected fallback value: -1") + } + + notBool := GetAsBool("monkey", false) + if notBool() { + t.Error("expected fallback value: false") + } +} + +func TestLayersSetters(t *testing.T) { //nolint:paralleltest + // reset + options = make(map[string]*Option) + + _ = Register(&Option{ + Name: "name", + Key: "monkey", + Description: "description", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: OptTypeString, + DefaultValue: "banana", + ValidationRegex: "^(banana|water)$", + }) + _ = Register(&Option{ + Name: "name", + Key: 
"zebras/zebra", + Description: "description", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: OptTypeStringArray, + DefaultValue: []string{"black", "white"}, + ValidationRegex: "^[a-z]+$", + }) + _ = Register(&Option{ + Name: "name", + Key: "elephant", + Description: "description", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: OptTypeInt, + DefaultValue: 2, + ValidationRegex: "", + }) + _ = Register(&Option{ + Name: "name", + Key: "hot", + Description: "description", + ReleaseLevel: ReleaseLevelStable, + ExpertiseLevel: ExpertiseLevelUser, + OptType: OptTypeBool, + DefaultValue: true, + ValidationRegex: "", + }) + + // correct types + if err := SetConfigOption("monkey", "banana"); err != nil { + t.Error(err) + } + if err := SetConfigOption("zebras/zebra", []string{"black", "white"}); err != nil { + t.Error(err) + } + if err := SetDefaultConfigOption("elephant", 2); err != nil { + t.Error(err) + } + if err := SetDefaultConfigOption("hot", true); err != nil { + t.Error(err) + } + + // incorrect types + if err := SetConfigOption("monkey", []string{"black", "white"}); err == nil { + t.Error("should fail") + } + if err := SetConfigOption("zebras/zebra", 2); err == nil { + t.Error("should fail") + } + if err := SetDefaultConfigOption("elephant", true); err == nil { + t.Error("should fail") + } + if err := SetDefaultConfigOption("hot", "banana"); err == nil { + t.Error("should fail") + } + if err := SetDefaultConfigOption("hot", []byte{0}); err == nil { + t.Error("should fail") + } + + // validation fail + if err := SetConfigOption("monkey", "dirt"); err == nil { + t.Error("should fail") + } + if err := SetConfigOption("zebras/zebra", []string{"Element649"}); err == nil { + t.Error("should fail") + } + + // unregistered checking + if err := SetConfigOption("invalid", "banana"); err == nil { + t.Error("should fail") + } + if err := SetConfigOption("invalid", []string{"black", "white"}); err == nil { + t.Error("should fail") + } + if err := SetConfigOption("invalid", 2); err == nil { + t.Error("should fail") + } + if err := SetConfigOption("invalid", true); err == nil { + t.Error("should fail") + } + if err := SetConfigOption("invalid", []byte{0}); err == nil { + t.Error("should fail") + } + + // delete + if err := SetConfigOption("monkey", nil); err != nil { + t.Error(err) + } + if err := SetDefaultConfigOption("elephant", nil); err != nil { + t.Error(err) + } + if err := SetDefaultConfigOption("invalid_delete", nil); err == nil { + t.Error("should fail") + } +} diff --git a/base/config/validate.go b/base/config/validate.go new file mode 100644 index 000000000..8d3408863 --- /dev/null +++ b/base/config/validate.go @@ -0,0 +1,239 @@ +package config + +import ( + "errors" + "fmt" + "math" + "reflect" + + "github.com/safing/portmaster/base/log" +) + +type valueCache struct { + stringVal string + stringArrayVal []string + intVal int64 + boolVal bool +} + +func (vc *valueCache) getData(opt *Option) interface{} { + switch opt.OptType { + case OptTypeBool: + return vc.boolVal + case OptTypeInt: + return vc.intVal + case OptTypeString: + return vc.stringVal + case OptTypeStringArray: + return vc.stringArrayVal + case optTypeAny: + return nil + default: + return nil + } +} + +// isAllowedPossibleValue checks if value is defined as a PossibleValue +// in opt. If there are not possible values defined value is considered +// allowed and nil is returned. 
isAllowedPossibleValue ensure the actual +// value is an allowed primitiv value by using reflection to convert +// value and each PossibleValue to a comparable primitiv if possible. +// In case of complex value types isAllowedPossibleValue uses +// reflect.DeepEqual as a fallback. +func isAllowedPossibleValue(opt *Option, value interface{}) error { + if opt.PossibleValues == nil { + return nil + } + + for _, val := range opt.PossibleValues { + compareAgainst := val.Value + valueType := reflect.TypeOf(value) + + // loading int's from the configuration JSON does not preserve the correct type + // as we get float64 instead. Make sure to convert them before. + if reflect.TypeOf(val.Value).ConvertibleTo(valueType) { + compareAgainst = reflect.ValueOf(val.Value).Convert(valueType).Interface() + } + if compareAgainst == value { + return nil + } + + if reflect.DeepEqual(val.Value, value) { + return nil + } + } + + return errors.New("value is not allowed") +} + +// migrateValue runs all value migrations. +func migrateValue(option *Option, value any) any { + for _, migration := range option.Migrations { + newValue := migration(option, value) + if newValue != value { + log.Debugf("config: migrated %s value from %v to %v", option.Key, value, newValue) + } + value = newValue + } + return value +} + +// validateValue ensures that value matches the expected type of option. +// It does not create a copy of the value! +func validateValue(option *Option, value interface{}) (*valueCache, *ValidationError) { //nolint:gocyclo + if option.OptType != OptTypeStringArray { + if err := isAllowedPossibleValue(option, value); err != nil { + return nil, &ValidationError{ + Option: option.copyOrNil(), + Err: err, + } + } + } + + var validated *valueCache + switch v := value.(type) { + case string: + if option.OptType != OptTypeString { + return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v) + } + if option.compiledRegex != nil { + if !option.compiledRegex.MatchString(v) { + return nil, invalid(option, "did not match validation regex") + } + } + validated = &valueCache{stringVal: v} + case []interface{}: + vConverted := make([]string, len(v)) + for pos, entry := range v { + s, ok := entry.(string) + if !ok { + return nil, invalid(option, "entry #%d is not a string", pos+1) + } + vConverted[pos] = s + } + // Call validation function again with converted value. + var vErr *ValidationError + validated, vErr = validateValue(option, vConverted) + if vErr != nil { + return nil, vErr + } + case []string: + if option.OptType != OptTypeStringArray { + return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v) + } + if option.compiledRegex != nil { + for pos, entry := range v { + if !option.compiledRegex.MatchString(entry) { + return nil, invalid(option, "entry #%d did not match validation regex", pos+1) + } + + if err := isAllowedPossibleValue(option, entry); err != nil { + return nil, invalid(option, "entry #%d is not allowed", pos+1) + } + } + } + validated = &valueCache{stringArrayVal: v} + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64: + // uint64 is omitted, as it does not fit in a int64 + if option.OptType != OptTypeInt { + return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v) + } + if option.compiledRegex != nil { + // we need to use %v here so we handle float and int correctly. 
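+ // (Illustrative: fmt.Sprintf("%v", 3) and fmt.Sprintf("%v", 3.0) both yield "3",
+ // while "%d" would render the float as "%!d(float64=3)" and never match a numeric validation regex.)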
+ if !option.compiledRegex.MatchString(fmt.Sprintf("%v", v)) { + return nil, invalid(option, "did not match validation regex") + } + } + switch v := value.(type) { + case int: + validated = &valueCache{intVal: int64(v)} + case int8: + validated = &valueCache{intVal: int64(v)} + case int16: + validated = &valueCache{intVal: int64(v)} + case int32: + validated = &valueCache{intVal: int64(v)} + case int64: + validated = &valueCache{intVal: v} + case uint: + validated = &valueCache{intVal: int64(v)} + case uint8: + validated = &valueCache{intVal: int64(v)} + case uint16: + validated = &valueCache{intVal: int64(v)} + case uint32: + validated = &valueCache{intVal: int64(v)} + case float32: + // convert if float has no decimals + if math.Remainder(float64(v), 1) == 0 { + validated = &valueCache{intVal: int64(v)} + } else { + return nil, invalid(option, "failed to convert float32 to int64") + } + case float64: + // convert if float has no decimals + if math.Remainder(v, 1) == 0 { + validated = &valueCache{intVal: int64(v)} + } else { + return nil, invalid(option, "failed to convert float64 to int64") + } + default: + return nil, invalid(option, "internal error") + } + case bool: + if option.OptType != OptTypeBool { + return nil, invalid(option, "expected type %s, got type %T", getTypeName(option.OptType), v) + } + validated = &valueCache{boolVal: v} + default: + return nil, invalid(option, "invalid option value type: %T", value) + } + + // Check if there is an additional function to validate the value. + if option.ValidationFunc != nil { + var err error + switch option.OptType { + case optTypeAny: + err = errors.New("internal error") + case OptTypeString: + err = option.ValidationFunc(validated.stringVal) + case OptTypeStringArray: + err = option.ValidationFunc(validated.stringArrayVal) + case OptTypeInt: + err = option.ValidationFunc(validated.intVal) + case OptTypeBool: + err = option.ValidationFunc(validated.boolVal) + } + if err != nil { + return nil, &ValidationError{ + Option: option.copyOrNil(), + Err: err, + } + } + } + + return validated, nil +} + +// ValidationError error holds details about a config option value validation error. +type ValidationError struct { + Option *Option + Err error +} + +// Error returns the formatted error. +func (ve *ValidationError) Error() string { + return fmt.Sprintf("validation of %s failed: %s", ve.Option.Key, ve.Err) +} + +// Unwrap returns the wrapped error. +func (ve *ValidationError) Unwrap() error { + return ve.Err +} + +func invalid(option *Option, format string, a ...interface{}) *ValidationError { + return &ValidationError{ + Option: option.copyOrNil(), + Err: fmt.Errorf(format, a...), + } +} diff --git a/base/config/validity.go b/base/config/validity.go new file mode 100644 index 000000000..925d306c5 --- /dev/null +++ b/base/config/validity.go @@ -0,0 +1,32 @@ +package config + +import ( + "github.com/tevino/abool" +) + +// ValidityFlag is a flag that signifies if the configuration has been changed. It is not safe for concurrent use. +type ValidityFlag struct { + flag *abool.AtomicBool +} + +// NewValidityFlag returns a flag that signifies if the configuration has been changed. +// It always starts out as invalid. Refresh to start with the current value. +func NewValidityFlag() *ValidityFlag { + vf := &ValidityFlag{ + flag: abool.New(), + } + return vf +} + +// IsValid returns if the configuration is still valid. +func (vf *ValidityFlag) IsValid() bool { + return vf.flag.IsSet() +} + +// Refresh refreshes the flag and makes it reusable. 
+func (vf *ValidityFlag) Refresh() { + validityFlagLock.RLock() + defer validityFlagLock.RUnlock() + + vf.flag = validityFlag +} diff --git a/base/container/container.go b/base/container/container.go new file mode 100644 index 000000000..775fc2053 --- /dev/null +++ b/base/container/container.go @@ -0,0 +1,368 @@ +package container + +import ( + "errors" + "io" + + "github.com/safing/portmaster/base/formats/varint" +) + +// Container is []byte sclie on steroids, allowing for quick data appending, prepending and fetching. +type Container struct { + compartments [][]byte + offset int + err error +} + +// Data Handling + +// NewContainer is DEPRECATED, please use New(), it's the same thing. +func NewContainer(data ...[]byte) *Container { + return &Container{ + compartments: data, + } +} + +// New creates a new container with an optional initial []byte slice. Data will NOT be copied. +func New(data ...[]byte) *Container { + return &Container{ + compartments: data, + } +} + +// Prepend prepends data. Data will NOT be copied. +func (c *Container) Prepend(data []byte) { + if c.offset < 1 { + c.renewCompartments() + } + c.offset-- + c.compartments[c.offset] = data +} + +// Append appends the given data. Data will NOT be copied. +func (c *Container) Append(data []byte) { + c.compartments = append(c.compartments, data) +} + +// PrependNumber prepends a number (varint encoded). +func (c *Container) PrependNumber(n uint64) { + c.Prepend(varint.Pack64(n)) +} + +// AppendNumber appends a number (varint encoded). +func (c *Container) AppendNumber(n uint64) { + c.compartments = append(c.compartments, varint.Pack64(n)) +} + +// PrependInt prepends an int (varint encoded). +func (c *Container) PrependInt(n int) { + c.Prepend(varint.Pack64(uint64(n))) +} + +// AppendInt appends an int (varint encoded). +func (c *Container) AppendInt(n int) { + c.compartments = append(c.compartments, varint.Pack64(uint64(n))) +} + +// AppendAsBlock appends the length of the data and the data itself. Data will NOT be copied. +func (c *Container) AppendAsBlock(data []byte) { + c.AppendNumber(uint64(len(data))) + c.Append(data) +} + +// PrependAsBlock prepends the length of the data and the data itself. Data will NOT be copied. +func (c *Container) PrependAsBlock(data []byte) { + c.Prepend(data) + c.PrependNumber(uint64(len(data))) +} + +// AppendContainer appends another Container. Data will NOT be copied. +func (c *Container) AppendContainer(data *Container) { + c.compartments = append(c.compartments, data.compartments...) +} + +// AppendContainerAsBlock appends another Container (length and data). Data will NOT be copied. +func (c *Container) AppendContainerAsBlock(data *Container) { + c.AppendNumber(uint64(data.Length())) + c.compartments = append(c.compartments, data.compartments...) +} + +// HoldsData returns true if the Container holds any data. +func (c *Container) HoldsData() bool { + for i := c.offset; i < len(c.compartments); i++ { + if len(c.compartments[i]) > 0 { + return true + } + } + return false +} + +// Length returns the full length of all bytes held by the container. +func (c *Container) Length() (length int) { + for i := c.offset; i < len(c.compartments); i++ { + length += len(c.compartments[i]) + } + return +} + +// Replace replaces all held data with a new data slice. Data will NOT be copied. +func (c *Container) Replace(data []byte) { + c.compartments = [][]byte{data} +} + +// CompileData concatenates all bytes held by the container and returns it as one single []byte slice. 
Data will NOT be copied and is NOT consumed. +func (c *Container) CompileData() []byte { + if len(c.compartments) != 1 { + newBuf := make([]byte, c.Length()) + copyBuf := newBuf + for i := c.offset; i < len(c.compartments); i++ { + copy(copyBuf, c.compartments[i]) + copyBuf = copyBuf[len(c.compartments[i]):] + } + c.compartments = [][]byte{newBuf} + c.offset = 0 + } + return c.compartments[0] +} + +// Get returns the given amount of bytes. Data MAY be copied and IS consumed. +func (c *Container) Get(n int) ([]byte, error) { + buf := c.Peek(n) + if len(buf) < n { + return nil, errors.New("container: not enough data to return") + } + c.skip(len(buf)) + return buf, nil +} + +// GetAll returns all data. Data MAY be copied and IS consumed. +func (c *Container) GetAll() []byte { + // TODO: Improve. + buf := c.Peek(c.Length()) + c.skip(len(buf)) + return buf +} + +// GetAsContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS consumed. +func (c *Container) GetAsContainer(n int) (*Container, error) { + newC := c.PeekContainer(n) + if newC == nil { + return nil, errors.New("container: not enough data to return") + } + c.skip(n) + return newC, nil +} + +// GetMax returns as much as possible, but the given amount of bytes at maximum. Data MAY be copied and IS consumed. +func (c *Container) GetMax(n int) []byte { + buf := c.Peek(n) + c.skip(len(buf)) + return buf +} + +// WriteToSlice copies data to the give slice until it is full, or the container is empty. It returns the bytes written and if the container is now empty. Data IS copied and IS consumed. +func (c *Container) WriteToSlice(slice []byte) (n int, containerEmptied bool) { + for i := c.offset; i < len(c.compartments); i++ { + copy(slice, c.compartments[i]) + if len(slice) < len(c.compartments[i]) { + // only part was copied + n += len(slice) + c.compartments[i] = c.compartments[i][len(slice):] + c.checkOffset() + return n, false + } + // all was copied + n += len(c.compartments[i]) + slice = slice[len(c.compartments[i]):] + c.compartments[i] = nil + c.offset = i + 1 + } + c.checkOffset() + return n, true +} + +// WriteAllTo writes all the data to the given io.Writer. Data IS NOT copied (but may be by writer) and IS NOT consumed. +func (c *Container) WriteAllTo(writer io.Writer) error { + for i := c.offset; i < len(c.compartments); i++ { + written := 0 + for written < len(c.compartments[i]) { + n, err := writer.Write(c.compartments[i][written:]) + if err != nil { + return err + } + written += n + } + } + return nil +} + +func (c *Container) clean() { + if c.offset > 100 { + c.renewCompartments() + } +} + +func (c *Container) renewCompartments() { + baseLength := len(c.compartments) - c.offset + 5 + newCompartments := make([][]byte, baseLength, baseLength+5) + copy(newCompartments[5:], c.compartments[c.offset:]) + c.compartments = newCompartments + c.offset = 4 +} + +func (c *Container) carbonCopy() *Container { + newC := &Container{ + compartments: make([][]byte, len(c.compartments)), + offset: c.offset, + err: c.err, + } + copy(newC.compartments, c.compartments) + return newC +} + +func (c *Container) checkOffset() { + if c.offset >= len(c.compartments) { + c.offset = len(c.compartments) / 2 + } +} + +// Block Handling + +// PrependLength prepends the current full length of all bytes in the container. +func (c *Container) PrependLength() { + c.Prepend(varint.Pack64(uint64(c.Length()))) +} + +// Peek returns the given amount of bytes. Data MAY be copied and IS NOT consumed. 
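+//
+// Rough sketch of Peek vs. Get (illustrative only):
+//
+//	c := New([]byte("he"), []byte("llo"))
+//	head := c.Peek(3)  // "hel", container still holds all 5 bytes
+//	head, _ = c.Get(3) // "hel", 2 bytes ("lo") remain in the container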
+func (c *Container) Peek(n int) []byte { + // Check requested length. + if n <= 0 { + return nil + } + + // Check if the first slice holds enough data. + if len(c.compartments[c.offset]) >= n { + return c.compartments[c.offset][:n] + } + + // Start gathering data. + slice := make([]byte, n) + copySlice := slice + n = 0 + for i := c.offset; i < len(c.compartments); i++ { + copy(copySlice, c.compartments[i]) + if len(copySlice) <= len(c.compartments[i]) { + n += len(copySlice) + return slice[:n] + } + n += len(c.compartments[i]) + copySlice = copySlice[len(c.compartments[i]):] + } + return slice[:n] +} + +// PeekContainer returns the given amount of bytes in a new container. Data will NOT be copied and IS NOT consumed. +func (c *Container) PeekContainer(n int) (newC *Container) { + // Check requested length. + if n < 0 { + return nil + } else if n == 0 { + return &Container{} + } + + newC = &Container{} + for i := c.offset; i < len(c.compartments); i++ { + if n >= len(c.compartments[i]) { + newC.compartments = append(newC.compartments, c.compartments[i]) + n -= len(c.compartments[i]) + } else { + newC.compartments = append(newC.compartments, c.compartments[i][:n]) + n = 0 + } + } + if n > 0 { + return nil + } + return newC +} + +func (c *Container) skip(n int) { + for i := c.offset; i < len(c.compartments); i++ { + if len(c.compartments[i]) <= n { + n -= len(c.compartments[i]) + c.offset = i + 1 + c.compartments[i] = nil + if n == 0 { + c.checkOffset() + return + } + } else { + c.compartments[i] = c.compartments[i][n:] + c.checkOffset() + return + } + } + c.checkOffset() +} + +// GetNextBlock returns the next block of data defined by a varint. Data MAY be copied and IS consumed. +func (c *Container) GetNextBlock() ([]byte, error) { + blockSize, err := c.GetNextN64() + if err != nil { + return nil, err + } + return c.Get(int(blockSize)) +} + +// GetNextBlockAsContainer returns the next block of data as a Container defined by a varint. Data will NOT be copied and IS consumed. +func (c *Container) GetNextBlockAsContainer() (*Container, error) { + blockSize, err := c.GetNextN64() + if err != nil { + return nil, err + } + return c.GetAsContainer(int(blockSize)) +} + +// GetNextN8 parses and returns a varint of type uint8. +func (c *Container) GetNextN8() (uint8, error) { + buf := c.Peek(2) + num, n, err := varint.Unpack8(buf) + if err != nil { + return 0, err + } + c.skip(n) + return num, nil +} + +// GetNextN16 parses and returns a varint of type uint16. +func (c *Container) GetNextN16() (uint16, error) { + buf := c.Peek(3) + num, n, err := varint.Unpack16(buf) + if err != nil { + return 0, err + } + c.skip(n) + return num, nil +} + +// GetNextN32 parses and returns a varint of type uint32. +func (c *Container) GetNextN32() (uint32, error) { + buf := c.Peek(5) + num, n, err := varint.Unpack32(buf) + if err != nil { + return 0, err + } + c.skip(n) + return num, nil +} + +// GetNextN64 parses and returns a varint of type uint64. 
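+//
+// Combined with AppendNumber this gives simple varint-prefixed framing,
+// roughly (illustrative only):
+//
+//	c := New()
+//	c.AppendNumber(1234)
+//	n, err := c.GetNextN64() // n == 1234, err == nil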
+func (c *Container) GetNextN64() (uint64, error) { + buf := c.Peek(10) + num, n, err := varint.Unpack64(buf) + if err != nil { + return 0, err + } + c.skip(n) + return num, nil +} diff --git a/base/container/container_test.go b/base/container/container_test.go new file mode 100644 index 000000000..bc8608b5f --- /dev/null +++ b/base/container/container_test.go @@ -0,0 +1,208 @@ +package container + +import ( + "bytes" + "testing" + + "github.com/safing/portmaster/base/utils" +) + +var ( + testData = []byte("The quick brown fox jumps over the lazy dog") + testDataSplitted = [][]byte{ + []byte("T"), + []byte("he"), + []byte(" qu"), + []byte("ick "), + []byte("brown"), + []byte(" fox j"), + []byte("umps ov"), + []byte("er the l"), + []byte("azy dog"), + } +) + +func TestContainerDataHandling(t *testing.T) { + t.Parallel() + + c1 := New(utils.DuplicateBytes(testData)) + c1c := c1.carbonCopy() + + c2 := New() + for i := 0; i < len(testData); i++ { + oneByte := make([]byte, 1) + c1c.WriteToSlice(oneByte) + c2.Append(oneByte) + } + c2c := c2.carbonCopy() + + c3 := New() + for i := len(c2c.compartments) - 1; i >= c2c.offset; i-- { + c3.Prepend(c2c.compartments[i]) + } + c3c := c3.carbonCopy() + + d4 := make([]byte, len(testData)*2) + n, _ := c3c.WriteToSlice(d4) + d4 = d4[:n] + c3c = c3.carbonCopy() + + d5 := make([]byte, len(testData)) + for i := 0; i < len(testData); i++ { + c3c.WriteToSlice(d5[i : i+1]) + } + + c6 := New() + c6.Replace(testData) + + c7 := New(testDataSplitted[0]) + for i := 1; i < len(testDataSplitted); i++ { + c7.Append(testDataSplitted[i]) + } + + c8 := New(testDataSplitted...) + for i := 0; i < 110; i++ { + c8.Prepend(nil) + } + c8.clean() + + c9 := c8.PeekContainer(len(testData)) + + c10 := c9.PeekContainer(len(testData) - 1) + c10.Append(testData[len(testData)-1:]) + + compareMany(t, testData, c1.CompileData(), c2.CompileData(), c3.CompileData(), d4, d5, c6.CompileData(), c7.CompileData(), c8.CompileData(), c9.CompileData(), c10.CompileData()) +} + +func compareMany(t *testing.T, reference []byte, other ...[]byte) { + t.Helper() + + for i, cmp := range other { + if !bytes.Equal(reference, cmp) { + t.Errorf("sample %d does not match reference: sample is '%s'", i+1, string(cmp)) + } + } +} + +func TestDataFetching(t *testing.T) { + t.Parallel() + + c1 := New(utils.DuplicateBytes(testData)) + data := c1.GetMax(1) + if string(data[0]) != "T" { + t.Errorf("failed to GetMax(1), got %s, expected %s", string(data), "T") + } + + _, err := c1.Get(1000) + if err == nil { + t.Error("should fail") + } + + _, err = c1.GetAsContainer(1000) + if err == nil { + t.Error("should fail") + } +} + +func TestBlocks(t *testing.T) { + t.Parallel() + + c1 := New(utils.DuplicateBytes(testData)) + c1.PrependLength() + + n, err := c1.GetNextN8() + if err != nil { + t.Errorf("GetNextN8() failed: %s", err) + } + if n != 43 { + t.Errorf("n should be 43, was %d", n) + } + c1.PrependLength() + + n2, err := c1.GetNextN16() + if err != nil { + t.Errorf("GetNextN16() failed: %s", err) + } + if n2 != 43 { + t.Errorf("n should be 43, was %d", n2) + } + c1.PrependLength() + + n3, err := c1.GetNextN32() + if err != nil { + t.Errorf("GetNextN32() failed: %s", err) + } + if n3 != 43 { + t.Errorf("n should be 43, was %d", n3) + } + c1.PrependLength() + + n4, err := c1.GetNextN64() + if err != nil { + t.Errorf("GetNextN64() failed: %s", err) + } + if n4 != 43 { + t.Errorf("n should be 43, was %d", n4) + } +} + +func TestContainerBlockHandling(t *testing.T) { + t.Parallel() + + c1 := 
New(utils.DuplicateBytes(testData)) + c1.PrependLength() + c1.AppendAsBlock(testData) + c1c := c1.carbonCopy() + + c2 := New(nil) + for i := 0; i < c1.Length(); i++ { + oneByte := make([]byte, 1) + c1c.WriteToSlice(oneByte) + c2.Append(oneByte) + } + + c3 := New(testDataSplitted[0]) + for i := 1; i < len(testDataSplitted); i++ { + c3.Append(testDataSplitted[i]) + } + c3.PrependLength() + + d1, err := c1.GetNextBlock() + if err != nil { + t.Errorf("GetNextBlock failed: %s", err) + } + d2, err := c1.GetNextBlock() + if err != nil { + t.Errorf("GetNextBlock failed: %s", err) + } + d3, err := c2.GetNextBlock() + if err != nil { + t.Errorf("GetNextBlock failed: %s", err) + } + d4, err := c2.GetNextBlock() + if err != nil { + t.Errorf("GetNextBlock failed: %s", err) + } + d5, err := c3.GetNextBlock() + if err != nil { + t.Errorf("GetNextBlock failed: %s", err) + } + + compareMany(t, testData, d1, d2, d3, d4, d5) +} + +func TestContainerMisc(t *testing.T) { + t.Parallel() + + c1 := New() + d1 := c1.CompileData() + if len(d1) > 0 { + t.Fatalf("empty container should not hold any data") + } +} + +func TestDeprecated(t *testing.T) { + t.Parallel() + + NewContainer(utils.DuplicateBytes(testData)) +} diff --git a/base/container/doc.go b/base/container/doc.go new file mode 100644 index 000000000..76cc73cc1 --- /dev/null +++ b/base/container/doc.go @@ -0,0 +1,26 @@ +// Package container gives you a []byte slice on steroids, allowing for quick data appending, prepending and fetching as well as transparent error transportation. +// +// A Container is basically a [][]byte slice that just appends new []byte slices and only copies things around when necessary. +// +// Byte slices added to the Container are not changed or appended, to not corrupt any other data that may be before and after the given slice. +// If interested, consider the following example to understand why this is important: +// +// package main +// +// import ( +// "fmt" +// ) +// +// func main() { +// a := []byte{0, 1,2,3,4,5,6,7,8,9} +// fmt.Printf("a: %+v\n", a) +// fmt.Printf("\nmaking changes...\n(we are not changing a directly)\n\n") +// b := a[2:6] +// c := append(b, 10, 11) +// fmt.Printf("b: %+v\n", b) +// fmt.Printf("c: %+v\n", c) +// fmt.Printf("a: %+v\n", a) +// } +// +// run it here: https://play.golang.org/p/xu1BXT3QYeE +package container diff --git a/base/container/serialization.go b/base/container/serialization.go new file mode 100644 index 000000000..d996c74d1 --- /dev/null +++ b/base/container/serialization.go @@ -0,0 +1,21 @@ +package container + +import ( + "encoding/json" +) + +// MarshalJSON serializes the container as a JSON byte array. +func (c *Container) MarshalJSON() ([]byte, error) { + return json.Marshal(c.CompileData()) +} + +// UnmarshalJSON unserializes a container from a JSON byte array. +func (c *Container) UnmarshalJSON(data []byte) error { + var raw []byte + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + c.compartments = [][]byte{raw} + return nil +} diff --git a/base/database/accessor/accessor-json-bytes.go b/base/database/accessor/accessor-json-bytes.go new file mode 100644 index 000000000..0c2b7c8c5 --- /dev/null +++ b/base/database/accessor/accessor-json-bytes.go @@ -0,0 +1,116 @@ +package accessor + +import ( + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// JSONBytesAccessor is a json string with get functions. +type JSONBytesAccessor struct { + json *[]byte +} + +// NewJSONBytesAccessor adds the Accessor interface to a JSON bytes string. 
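+//
+// Rough usage sketch (illustrative only):
+//
+//	data := []byte(`{"Name":"test","Score":7}`)
+//	acc := NewJSONBytesAccessor(&data)
+//	name, ok := acc.GetString("Name") // "test", true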
+func NewJSONBytesAccessor(json *[]byte) *JSONBytesAccessor { + return &JSONBytesAccessor{ + json: json, + } +} + +// Set sets the value identified by key. +func (ja *JSONBytesAccessor) Set(key string, value interface{}) error { + result := gjson.GetBytes(*ja.json, key) + if result.Exists() { + err := checkJSONValueType(result, key, value) + if err != nil { + return err + } + } + + newJSON, err := sjson.SetBytes(*ja.json, key, value) + if err != nil { + return err + } + *ja.json = newJSON + return nil +} + +// Get returns the value found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) Get(key string) (value interface{}, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() { + return nil, false + } + return result.Value(), true +} + +// GetString returns the string found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetString(key string) (value string, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() || result.Type != gjson.String { + return emptyString, false + } + return result.String(), true +} + +// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetStringArray(key string) (value []string, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() && !result.IsArray() { + return nil, false + } + slice := result.Array() + sliceCopy := make([]string, len(slice)) + for i, res := range slice { + if res.Type == gjson.String { + sliceCopy[i] = res.String() + } else { + return nil, false + } + } + return sliceCopy, true +} + +// GetInt returns the int found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetInt(key string) (value int64, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() || result.Type != gjson.Number { + return 0, false + } + return result.Int(), true +} + +// GetFloat returns the float found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetFloat(key string) (value float64, ok bool) { + result := gjson.GetBytes(*ja.json, key) + if !result.Exists() || result.Type != gjson.Number { + return 0, false + } + return result.Float(), true +} + +// GetBool returns the bool found by the given json key and whether it could be successfully extracted. +func (ja *JSONBytesAccessor) GetBool(key string) (value bool, ok bool) { + result := gjson.GetBytes(*ja.json, key) + switch { + case !result.Exists(): + return false, false + case result.Type == gjson.True: + return true, true + case result.Type == gjson.False: + return false, true + default: + return false, false + } +} + +// Exists returns the whether the given key exists. +func (ja *JSONBytesAccessor) Exists(key string) bool { + result := gjson.GetBytes(*ja.json, key) + return result.Exists() +} + +// Type returns the accessor type as a string. +func (ja *JSONBytesAccessor) Type() string { + return "JSONBytesAccessor" +} diff --git a/base/database/accessor/accessor-json-string.go b/base/database/accessor/accessor-json-string.go new file mode 100644 index 000000000..0a2767f08 --- /dev/null +++ b/base/database/accessor/accessor-json-string.go @@ -0,0 +1,140 @@ +package accessor + +import ( + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// JSONAccessor is a json string with get functions. 
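+//
+// It works like JSONBytesAccessor, but on a *string; Set writes the updated
+// JSON back through the pointer. Rough sketch (illustrative only):
+//
+//	j := `{"Score":7}`
+//	acc := NewJSONAccessor(&j)
+//	_ = acc.Set("Score", 9) // j now holds {"Score":9}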
+type JSONAccessor struct { + json *string +} + +// NewJSONAccessor adds the Accessor interface to a JSON string. +func NewJSONAccessor(json *string) *JSONAccessor { + return &JSONAccessor{ + json: json, + } +} + +// Set sets the value identified by key. +func (ja *JSONAccessor) Set(key string, value interface{}) error { + result := gjson.Get(*ja.json, key) + if result.Exists() { + err := checkJSONValueType(result, key, value) + if err != nil { + return err + } + } + + newJSON, err := sjson.Set(*ja.json, key, value) + if err != nil { + return err + } + *ja.json = newJSON + return nil +} + +func checkJSONValueType(jsonValue gjson.Result, key string, value interface{}) error { + switch value.(type) { + case string: + if jsonValue.Type != gjson.String { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value) + } + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + if jsonValue.Type != gjson.Number { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value) + } + case bool: + if jsonValue.Type != gjson.True && jsonValue.Type != gjson.False { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value) + } + case []string: + if !jsonValue.IsArray() { + return fmt.Errorf("tried to set field %s (%s) to a %T value", key, jsonValue.Type.String(), value) + } + } + return nil +} + +// Get returns the value found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) Get(key string) (value interface{}, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() { + return nil, false + } + return result.Value(), true +} + +// GetString returns the string found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetString(key string) (value string, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() || result.Type != gjson.String { + return emptyString, false + } + return result.String(), true +} + +// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetStringArray(key string) (value []string, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() && !result.IsArray() { + return nil, false + } + slice := result.Array() + sliceCopy := make([]string, len(slice)) + for i, res := range slice { + if res.Type == gjson.String { + sliceCopy[i] = res.String() + } else { + return nil, false + } + } + return sliceCopy, true +} + +// GetInt returns the int found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetInt(key string) (value int64, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() || result.Type != gjson.Number { + return 0, false + } + return result.Int(), true +} + +// GetFloat returns the float found by the given json key and whether it could be successfully extracted. +func (ja *JSONAccessor) GetFloat(key string) (value float64, ok bool) { + result := gjson.Get(*ja.json, key) + if !result.Exists() || result.Type != gjson.Number { + return 0, false + } + return result.Float(), true +} + +// GetBool returns the bool found by the given json key and whether it could be successfully extracted. 
+func (ja *JSONAccessor) GetBool(key string) (value bool, ok bool) { + result := gjson.Get(*ja.json, key) + switch { + case !result.Exists(): + return false, false + case result.Type == gjson.True: + return true, true + case result.Type == gjson.False: + return false, true + default: + return false, false + } +} + +// Exists returns the whether the given key exists. +func (ja *JSONAccessor) Exists(key string) bool { + result := gjson.Get(*ja.json, key) + return result.Exists() +} + +// Type returns the accessor type as a string. +func (ja *JSONAccessor) Type() string { + return "JSONAccessor" +} diff --git a/base/database/accessor/accessor-struct.go b/base/database/accessor/accessor-struct.go new file mode 100644 index 000000000..97c46a2aa --- /dev/null +++ b/base/database/accessor/accessor-struct.go @@ -0,0 +1,169 @@ +package accessor + +import ( + "errors" + "fmt" + "reflect" +) + +// StructAccessor is a json string with get functions. +type StructAccessor struct { + object reflect.Value +} + +// NewStructAccessor adds the Accessor interface to a JSON string. +func NewStructAccessor(object interface{}) *StructAccessor { + return &StructAccessor{ + object: reflect.ValueOf(object).Elem(), + } +} + +// Set sets the value identified by key. +func (sa *StructAccessor) Set(key string, value interface{}) error { + field := sa.object.FieldByName(key) + if !field.IsValid() { + return errors.New("struct field does not exist") + } + if !field.CanSet() { + return fmt.Errorf("field %s or struct is immutable", field.String()) + } + + newVal := reflect.ValueOf(value) + + // set directly if type matches + if newVal.Kind() == field.Kind() { + field.Set(newVal) + return nil + } + + // handle special cases + switch field.Kind() { // nolint:exhaustive + + // ints + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var newInt int64 + switch newVal.Kind() { // nolint:exhaustive + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + newInt = newVal.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + newInt = int64(newVal.Uint()) + default: + return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String()) + } + if field.OverflowInt(newInt) { + return fmt.Errorf("setting field %s (%s) to %d would overflow", key, field.Kind().String(), newInt) + } + field.SetInt(newInt) + + // uints + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var newUint uint64 + switch newVal.Kind() { // nolint:exhaustive + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + newUint = uint64(newVal.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + newUint = newVal.Uint() + default: + return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String()) + } + if field.OverflowUint(newUint) { + return fmt.Errorf("setting field %s (%s) to %d would overflow", key, field.Kind().String(), newUint) + } + field.SetUint(newUint) + + // floats + case reflect.Float32, reflect.Float64: + switch newVal.Kind() { // nolint:exhaustive + case reflect.Float32, reflect.Float64: + field.SetFloat(newVal.Float()) + default: + return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String()) + } + default: + return fmt.Errorf("tried to set field %s (%s) to a %s value", key, field.Kind().String(), newVal.Kind().String()) + } + 
+ return nil +} + +// Get returns the value found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) Get(key string) (value interface{}, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() || !field.CanInterface() { + return nil, false + } + return field.Interface(), true +} + +// GetString returns the string found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetString(key string) (value string, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() || field.Kind() != reflect.String { + return "", false + } + return field.String(), true +} + +// GetStringArray returns the []string found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetStringArray(key string) (value []string, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() || field.Kind() != reflect.Slice || !field.CanInterface() { + return nil, false + } + v := field.Interface() + slice, ok := v.([]string) + if !ok { + return nil, false + } + return slice, true +} + +// GetInt returns the int found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetInt(key string) (value int64, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() { + return 0, false + } + switch field.Kind() { // nolint:exhaustive + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return field.Int(), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(field.Uint()), true + default: + return 0, false + } +} + +// GetFloat returns the float found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetFloat(key string) (value float64, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() { + return 0, false + } + switch field.Kind() { // nolint:exhaustive + case reflect.Float32, reflect.Float64: + return field.Float(), true + default: + return 0, false + } +} + +// GetBool returns the bool found by the given json key and whether it could be successfully extracted. +func (sa *StructAccessor) GetBool(key string) (value bool, ok bool) { + field := sa.object.FieldByName(key) + if !field.IsValid() || field.Kind() != reflect.Bool { + return false, false + } + return field.Bool(), true +} + +// Exists returns the whether the given key exists. +func (sa *StructAccessor) Exists(key string) bool { + field := sa.object.FieldByName(key) + return field.IsValid() +} + +// Type returns the accessor type as a string. +func (sa *StructAccessor) Type() string { + return "StructAccessor" +} diff --git a/base/database/accessor/accessor.go b/base/database/accessor/accessor.go new file mode 100644 index 000000000..67a5373f3 --- /dev/null +++ b/base/database/accessor/accessor.go @@ -0,0 +1,18 @@ +package accessor + +const ( + emptyString = "" +) + +// Accessor provides an interface to supply the query matcher a method to retrieve values from an object. 
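+//
+// JSONAccessor, JSONBytesAccessor and StructAccessor all implement this
+// interface, so matching code can treat raw JSON and Go structs uniformly,
+// e.g. (illustrative only):
+//
+//	func scoreOf(acc Accessor) (int64, bool) {
+//		return acc.GetInt("Score")
+//	}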
+type Accessor interface { + Get(key string) (value interface{}, ok bool) + GetString(key string) (value string, ok bool) + GetStringArray(key string) (value []string, ok bool) + GetInt(key string) (value int64, ok bool) + GetFloat(key string) (value float64, ok bool) + GetBool(key string) (value bool, ok bool) + Exists(key string) bool + Set(key string, value interface{}) error + Type() string +} diff --git a/base/database/accessor/accessor_test.go b/base/database/accessor/accessor_test.go new file mode 100644 index 000000000..69d41bc9a --- /dev/null +++ b/base/database/accessor/accessor_test.go @@ -0,0 +1,291 @@ +//nolint:maligned,unparam +package accessor + +import ( + "encoding/json" + "testing" + + "github.com/safing/portmaster/base/utils" +) + +type TestStruct struct { + S string + A []string + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + F32 float32 + F64 float64 + B bool +} + +var ( + testStruct = &TestStruct{ + S: "banana", + A: []string{"black", "white"}, + I: 42, + I8: 42, + I16: 42, + I32: 42, + I64: 42, + UI: 42, + UI8: 42, + UI16: 42, + UI32: 42, + UI64: 42, + F32: 42.42, + F64: 42.42, + B: true, + } + testJSONBytes, _ = json.Marshal(testStruct) //nolint:errchkjson + testJSON = string(testJSONBytes) +) + +func testGetString(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue string) { + t.Helper() + + v, ok := acc.GetString(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get string with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get string with key %s, it returned %v", acc.Type(), key, v) + } + if v != expectedValue { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testGetStringArray(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue []string) { + t.Helper() + + v, ok := acc.GetStringArray(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get []string with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get []string with key %s, it returned %v", acc.Type(), key, v) + } + if !utils.StringSliceEqual(v, expectedValue) { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testGetInt(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue int64) { + t.Helper() + + v, ok := acc.GetInt(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get int with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get int with key %s, it returned %v", acc.Type(), key, v) + } + if v != expectedValue { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testGetFloat(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue float64) { + t.Helper() + + v, ok := acc.GetFloat(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get float with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get float with key %s, it returned %v", acc.Type(), key, v) + } + if int64(v) != int64(expectedValue) { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testGetBool(t *testing.T, acc Accessor, key string, shouldSucceed bool, expectedValue bool) { + t.Helper() + + v, ok := 
acc.GetBool(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s failed to get bool with key %s", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should have failed to get bool with key %s, it returned %v", acc.Type(), key, v) + } + if v != expectedValue { + t.Errorf("%s returned an unexpected value: wanted %v, got %v", acc.Type(), expectedValue, v) + } +} + +func testExists(t *testing.T, acc Accessor, key string, shouldSucceed bool) { + t.Helper() + + ok := acc.Exists(key) + switch { + case !ok && shouldSucceed: + t.Errorf("%s should report key %s as existing", acc.Type(), key) + case ok && !shouldSucceed: + t.Errorf("%s should report key %s as non-existing", acc.Type(), key) + } +} + +func testSet(t *testing.T, acc Accessor, key string, shouldSucceed bool, valueToSet interface{}) { + t.Helper() + + err := acc.Set(key, valueToSet) + switch { + case err != nil && shouldSucceed: + t.Errorf("%s failed to set %s to %+v: %s", acc.Type(), key, valueToSet, err) + case err == nil && !shouldSucceed: + t.Errorf("%s should have failed to set %s to %+v", acc.Type(), key, valueToSet) + } +} + +func TestAccessor(t *testing.T) { + t.Parallel() + + // Test interface compliance. + accs := []Accessor{ + NewJSONAccessor(&testJSON), + NewJSONBytesAccessor(&testJSONBytes), + NewStructAccessor(testStruct), + } + + // get + for _, acc := range accs { + testGetString(t, acc, "S", true, "banana") + testGetStringArray(t, acc, "A", true, []string{"black", "white"}) + testGetInt(t, acc, "I", true, 42) + testGetInt(t, acc, "I8", true, 42) + testGetInt(t, acc, "I16", true, 42) + testGetInt(t, acc, "I32", true, 42) + testGetInt(t, acc, "I64", true, 42) + testGetInt(t, acc, "UI", true, 42) + testGetInt(t, acc, "UI8", true, 42) + testGetInt(t, acc, "UI16", true, 42) + testGetInt(t, acc, "UI32", true, 42) + testGetInt(t, acc, "UI64", true, 42) + testGetFloat(t, acc, "F32", true, 42.42) + testGetFloat(t, acc, "F64", true, 42.42) + testGetBool(t, acc, "B", true, true) + } + + // set + for _, acc := range accs { + testSet(t, acc, "S", true, "coconut") + testSet(t, acc, "A", true, []string{"green", "blue"}) + testSet(t, acc, "I", true, uint32(44)) + testSet(t, acc, "I8", true, uint64(44)) + testSet(t, acc, "I16", true, uint8(44)) + testSet(t, acc, "I32", true, uint16(44)) + testSet(t, acc, "I64", true, 44) + testSet(t, acc, "UI", true, 44) + testSet(t, acc, "UI8", true, int64(44)) + testSet(t, acc, "UI16", true, int32(44)) + testSet(t, acc, "UI32", true, int8(44)) + testSet(t, acc, "UI64", true, int16(44)) + testSet(t, acc, "F32", true, 44.44) + testSet(t, acc, "F64", true, 44.44) + testSet(t, acc, "B", true, false) + } + + // get again to check if new values were set + for _, acc := range accs { + testGetString(t, acc, "S", true, "coconut") + testGetStringArray(t, acc, "A", true, []string{"green", "blue"}) + testGetInt(t, acc, "I", true, 44) + testGetInt(t, acc, "I8", true, 44) + testGetInt(t, acc, "I16", true, 44) + testGetInt(t, acc, "I32", true, 44) + testGetInt(t, acc, "I64", true, 44) + testGetInt(t, acc, "UI", true, 44) + testGetInt(t, acc, "UI8", true, 44) + testGetInt(t, acc, "UI16", true, 44) + testGetInt(t, acc, "UI32", true, 44) + testGetInt(t, acc, "UI64", true, 44) + testGetFloat(t, acc, "F32", true, 44.44) + testGetFloat(t, acc, "F64", true, 44.44) + testGetBool(t, acc, "B", true, false) + } + + // failures + for _, acc := range accs { + testSet(t, acc, "S", false, true) + testSet(t, acc, "S", false, false) + testSet(t, acc, "S", false, 1) + testSet(t, acc, "S", false, 1.1) + + testSet(t, 
acc, "A", false, "1") + testSet(t, acc, "A", false, true) + testSet(t, acc, "A", false, false) + testSet(t, acc, "A", false, 1) + testSet(t, acc, "A", false, 1.1) + + testSet(t, acc, "I", false, "1") + testSet(t, acc, "I8", false, "1") + testSet(t, acc, "I16", false, "1") + testSet(t, acc, "I32", false, "1") + testSet(t, acc, "I64", false, "1") + testSet(t, acc, "UI", false, "1") + testSet(t, acc, "UI8", false, "1") + testSet(t, acc, "UI16", false, "1") + testSet(t, acc, "UI32", false, "1") + testSet(t, acc, "UI64", false, "1") + + testSet(t, acc, "F32", false, "1.1") + testSet(t, acc, "F64", false, "1.1") + + testSet(t, acc, "B", false, "false") + testSet(t, acc, "B", false, 1) + testSet(t, acc, "B", false, 1.1) + } + + // get again to check if values werent changed when an error occurred + for _, acc := range accs { + testGetString(t, acc, "S", true, "coconut") + testGetStringArray(t, acc, "A", true, []string{"green", "blue"}) + testGetInt(t, acc, "I", true, 44) + testGetInt(t, acc, "I8", true, 44) + testGetInt(t, acc, "I16", true, 44) + testGetInt(t, acc, "I32", true, 44) + testGetInt(t, acc, "I64", true, 44) + testGetInt(t, acc, "UI", true, 44) + testGetInt(t, acc, "UI8", true, 44) + testGetInt(t, acc, "UI16", true, 44) + testGetInt(t, acc, "UI32", true, 44) + testGetInt(t, acc, "UI64", true, 44) + testGetFloat(t, acc, "F32", true, 44.44) + testGetFloat(t, acc, "F64", true, 44.44) + testGetBool(t, acc, "B", true, false) + } + + // test existence + for _, acc := range accs { + testExists(t, acc, "S", true) + testExists(t, acc, "A", true) + testExists(t, acc, "I", true) + testExists(t, acc, "I8", true) + testExists(t, acc, "I16", true) + testExists(t, acc, "I32", true) + testExists(t, acc, "I64", true) + testExists(t, acc, "UI", true) + testExists(t, acc, "UI8", true) + testExists(t, acc, "UI16", true) + testExists(t, acc, "UI32", true) + testExists(t, acc, "UI64", true) + testExists(t, acc, "F32", true) + testExists(t, acc, "F64", true) + testExists(t, acc, "B", true) + } + + // test non-existence + for _, acc := range accs { + testExists(t, acc, "X", false) + } +} diff --git a/base/database/boilerplate_test.go b/base/database/boilerplate_test.go new file mode 100644 index 000000000..d6da0d322 --- /dev/null +++ b/base/database/boilerplate_test.go @@ -0,0 +1,65 @@ +package database + +import ( + "fmt" + "sync" + + "github.com/safing/portmaster/base/database/record" +) + +type Example struct { + record.Base + sync.Mutex + + Name string + Score int +} + +var exampleDB = NewInterface(&Options{ + Internal: true, + Local: true, +}) + +// GetExample gets an Example from the database. 
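+//
+// Record keys follow the "<database name>:<record key>" pattern used in the
+// tests below; a rough call might look like this (illustrative, placeholder key):
+//
+//	example, err := GetExample("my-db:A")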
+func GetExample(key string) (*Example, error) { + r, err := exampleDB.Get(key) + if err != nil { + return nil, err + } + + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + newExample := &Example{} + err = record.Unwrap(r, newExample) + if err != nil { + return nil, err + } + return newExample, nil + } + + // or adjust type + newExample, ok := r.(*Example) + if !ok { + return nil, fmt.Errorf("record not of type *Example, but %T", r) + } + return newExample, nil +} + +func (e *Example) Save() error { + return exampleDB.Put(e) +} + +func (e *Example) SaveAs(key string) error { + e.SetKey(key) + return exampleDB.PutNew(e) +} + +func NewExample(key, name string, score int) *Example { + newExample := &Example{ + Name: name, + Score: score, + } + newExample.SetKey(key) + return newExample +} diff --git a/base/database/controller.go b/base/database/controller.go new file mode 100644 index 000000000..4d95c01e4 --- /dev/null +++ b/base/database/controller.go @@ -0,0 +1,355 @@ +package database + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" +) + +// A Controller takes care of all the extra database logic. +type Controller struct { + database *Database + storage storage.Interface + shadowDelete bool + + hooksLock sync.RWMutex + hooks []*RegisteredHook + + subscriptionLock sync.RWMutex + subscriptions []*Subscription +} + +// newController creates a new controller for a storage. +func newController(database *Database, storageInt storage.Interface, shadowDelete bool) *Controller { + return &Controller{ + database: database, + storage: storageInt, + shadowDelete: shadowDelete, + } +} + +// ReadOnly returns whether the storage is read only. +func (c *Controller) ReadOnly() bool { + return c.storage.ReadOnly() +} + +// Injected returns whether the storage is injected. +func (c *Controller) Injected() bool { + return c.storage.Injected() +} + +// Get returns the record with the given key. +func (c *Controller) Get(key string) (record.Record, error) { + if shuttingDown.IsSet() { + return nil, ErrShuttingDown + } + + if err := c.runPreGetHooks(key); err != nil { + return nil, err + } + + r, err := c.storage.Get(key) + if err != nil { + // replace not found error + if errors.Is(err, storage.ErrNotFound) { + return nil, ErrNotFound + } + return nil, err + } + + r.Lock() + defer r.Unlock() + + r, err = c.runPostGetHooks(r) + if err != nil { + return nil, err + } + + if !r.Meta().CheckValidity() { + return nil, ErrNotFound + } + + return r, nil +} + +// GetMeta returns the metadata of the record with the given key. 
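+// If the storage implements storage.MetaHandler, the metadata is loaded
+// directly; otherwise the full record is fetched and only its Meta is returned.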
+func (c *Controller) GetMeta(key string) (*record.Meta, error) { + if shuttingDown.IsSet() { + return nil, ErrShuttingDown + } + + var m *record.Meta + var err error + if metaDB, ok := c.storage.(storage.MetaHandler); ok { + m, err = metaDB.GetMeta(key) + if err != nil { + // replace not found error + if errors.Is(err, storage.ErrNotFound) { + return nil, ErrNotFound + } + return nil, err + } + } else { + r, err := c.storage.Get(key) + if err != nil { + // replace not found error + if errors.Is(err, storage.ErrNotFound) { + return nil, ErrNotFound + } + return nil, err + } + m = r.Meta() + } + + if !m.CheckValidity() { + return nil, ErrNotFound + } + + return m, nil +} + +// Put saves a record in the database, executes any registered +// pre-put hooks and finally send an update to all subscribers. +// The record must be locked and secured from concurrent access +// when calling Put(). +func (c *Controller) Put(r record.Record) (err error) { + if shuttingDown.IsSet() { + return ErrShuttingDown + } + + if c.ReadOnly() { + return ErrReadOnly + } + + r, err = c.runPrePutHooks(r) + if err != nil { + return err + } + + if !c.shadowDelete && r.Meta().IsDeleted() { + // Immediate delete. + err = c.storage.Delete(r.DatabaseKey()) + } else { + // Put or shadow delete. + r, err = c.storage.Put(r) + } + + if err != nil { + return err + } + + if r == nil { + return errors.New("storage returned nil record after successful put operation") + } + + c.notifySubscribers(r) + + return nil +} + +// PutMany stores many records in the database. It does not +// process any hooks or update subscriptions. Use with care! +func (c *Controller) PutMany() (chan<- record.Record, <-chan error) { + if shuttingDown.IsSet() { + errs := make(chan error, 1) + errs <- ErrShuttingDown + return make(chan record.Record), errs + } + + if c.ReadOnly() { + errs := make(chan error, 1) + errs <- ErrReadOnly + return make(chan record.Record), errs + } + + if batcher, ok := c.storage.(storage.Batcher); ok { + return batcher.PutMany(c.shadowDelete) + } + + errs := make(chan error, 1) + errs <- ErrNotImplemented + return make(chan record.Record), errs +} + +// Query executes the given query on the database. +func (c *Controller) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + if shuttingDown.IsSet() { + return nil, ErrShuttingDown + } + + it, err := c.storage.Query(q, local, internal) + if err != nil { + return nil, err + } + + return it, nil +} + +// PushUpdate pushes a record update to subscribers. +// The caller must hold the record's lock when calling +// PushUpdate. +func (c *Controller) PushUpdate(r record.Record) { + if c != nil { + if shuttingDown.IsSet() { + return + } + + c.notifySubscribers(r) + } +} + +func (c *Controller) addSubscription(sub *Subscription) { + if shuttingDown.IsSet() { + return + } + + c.subscriptionLock.Lock() + defer c.subscriptionLock.Unlock() + + c.subscriptions = append(c.subscriptions, sub) +} + +// Maintain runs the Maintain method on the storage. +func (c *Controller) Maintain(ctx context.Context) error { + if shuttingDown.IsSet() { + return ErrShuttingDown + } + + if maintainer, ok := c.storage.(storage.Maintainer); ok { + return maintainer.Maintain(ctx) + } + return nil +} + +// MaintainThorough runs the MaintainThorough method on the +// storage. 
+func (c *Controller) MaintainThorough(ctx context.Context) error { + if shuttingDown.IsSet() { + return ErrShuttingDown + } + + if maintainer, ok := c.storage.(storage.Maintainer); ok { + return maintainer.MaintainThorough(ctx) + } + return nil +} + +// MaintainRecordStates runs the record state lifecycle +// maintenance on the storage. +func (c *Controller) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time) error { + if shuttingDown.IsSet() { + return ErrShuttingDown + } + + return c.storage.MaintainRecordStates(ctx, purgeDeletedBefore, c.shadowDelete) +} + +// Purge deletes all records that match the given query. +// It returns the number of successful deletes and an error. +func (c *Controller) Purge(ctx context.Context, q *query.Query, local, internal bool) (int, error) { + if shuttingDown.IsSet() { + return 0, ErrShuttingDown + } + + if purger, ok := c.storage.(storage.Purger); ok { + return purger.Purge(ctx, q, local, internal, c.shadowDelete) + } + + return 0, ErrNotImplemented +} + +// Shutdown shuts down the storage. +func (c *Controller) Shutdown() error { + return c.storage.Shutdown() +} + +// notifySubscribers notifies all subscribers that are interested +// in r. r must be locked when calling notifySubscribers. +// Any subscriber that is not blocking on it's feed channel will +// be skipped. +func (c *Controller) notifySubscribers(r record.Record) { + c.subscriptionLock.RLock() + defer c.subscriptionLock.RUnlock() + + for _, sub := range c.subscriptions { + if r.Meta().CheckPermission(sub.local, sub.internal) && sub.q.Matches(r) { + select { + case sub.Feed <- r: + default: + } + } + } +} + +func (c *Controller) runPreGetHooks(key string) error { + c.hooksLock.RLock() + defer c.hooksLock.RUnlock() + + for _, hook := range c.hooks { + if !hook.h.UsesPreGet() { + continue + } + + if !hook.q.MatchesKey(key) { + continue + } + + if err := hook.h.PreGet(key); err != nil { + return err + } + } + + return nil +} + +func (c *Controller) runPostGetHooks(r record.Record) (record.Record, error) { + c.hooksLock.RLock() + defer c.hooksLock.RUnlock() + + var err error + for _, hook := range c.hooks { + if !hook.h.UsesPostGet() { + continue + } + + if !hook.q.Matches(r) { + continue + } + + r, err = hook.h.PostGet(r) + if err != nil { + return nil, err + } + } + + return r, nil +} + +func (c *Controller) runPrePutHooks(r record.Record) (record.Record, error) { + c.hooksLock.RLock() + defer c.hooksLock.RUnlock() + + var err error + for _, hook := range c.hooks { + if !hook.h.UsesPrePut() { + continue + } + + if !hook.q.Matches(r) { + continue + } + + r, err = hook.h.PrePut(r) + if err != nil { + return nil, err + } + } + + return r, nil +} diff --git a/base/database/controllers.go b/base/database/controllers.go new file mode 100644 index 000000000..954807d1f --- /dev/null +++ b/base/database/controllers.go @@ -0,0 +1,106 @@ +package database + +import ( + "errors" + "fmt" + "sync" + + "github.com/safing/portmaster/base/database/storage" +) + +// StorageTypeInjected is the type of injected databases. 
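+// Databases registered with this storage type are not started by
+// getController; they must be provided via InjectDatabase instead.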
+const StorageTypeInjected = "injected" + +var ( + controllers = make(map[string]*Controller) + controllersLock sync.RWMutex +) + +func getController(name string) (*Controller, error) { + if !initialized.IsSet() { + return nil, errors.New("database not initialized") + } + + // return database if already started + controllersLock.RLock() + controller, ok := controllers[name] + controllersLock.RUnlock() + if ok { + return controller, nil + } + + controllersLock.Lock() + defer controllersLock.Unlock() + + if shuttingDown.IsSet() { + return nil, ErrShuttingDown + } + + // get db registration + registeredDB, err := getDatabase(name) + if err != nil { + return nil, fmt.Errorf("could not start database %s: %w", name, err) + } + + // Check if database is injected. + if registeredDB.StorageType == StorageTypeInjected { + return nil, fmt.Errorf("database storage is not injected") + } + + // get location + dbLocation, err := getLocation(name, registeredDB.StorageType) + if err != nil { + return nil, fmt.Errorf("could not start database %s (type %s): %w", name, registeredDB.StorageType, err) + } + + // start database + storageInt, err := storage.StartDatabase(name, registeredDB.StorageType, dbLocation) + if err != nil { + return nil, fmt.Errorf("could not start database %s (type %s): %w", name, registeredDB.StorageType, err) + } + + controller = newController(registeredDB, storageInt, registeredDB.ShadowDelete) + controllers[name] = controller + return controller, nil +} + +// InjectDatabase injects an already running database into the system. +func InjectDatabase(name string, storageInt storage.Interface) (*Controller, error) { + controllersLock.Lock() + defer controllersLock.Unlock() + + if shuttingDown.IsSet() { + return nil, ErrShuttingDown + } + + _, ok := controllers[name] + if ok { + return nil, fmt.Errorf(`database "%s" already loaded`, name) + } + + registryLock.Lock() + defer registryLock.Unlock() + + // check if database is registered + registeredDB, ok := registry[name] + if !ok { + return nil, fmt.Errorf("database %q not registered", name) + } + if registeredDB.StorageType != StorageTypeInjected { + return nil, fmt.Errorf("database not of type %q", StorageTypeInjected) + } + + controller := newController(registeredDB, storageInt, false) + controllers[name] = controller + return controller, nil +} + +// Withdraw withdraws an injected database, but leaves the database registered. +func (c *Controller) Withdraw() { + if c != nil && c.Injected() { + controllersLock.Lock() + defer controllersLock.Unlock() + + delete(controllers, c.database.Name) + } +} diff --git a/base/database/database.go b/base/database/database.go new file mode 100644 index 000000000..332c3f81d --- /dev/null +++ b/base/database/database.go @@ -0,0 +1,26 @@ +package database + +import ( + "time" +) + +// Database holds information about a registered database. +type Database struct { + Name string + Description string + StorageType string + ShadowDelete bool // Whether deleted records should be kept until purged. + Registered time.Time + LastUpdated time.Time + LastLoaded time.Time +} + +// Loaded updates the LastLoaded timestamp. +func (db *Database) Loaded() { + db.LastLoaded = time.Now().Round(time.Second) +} + +// Updated updates the LastUpdated timestamp. 
+func (db *Database) Updated() { + db.LastUpdated = time.Now().Round(time.Second) +} diff --git a/base/database/database_test.go b/base/database/database_test.go new file mode 100644 index 000000000..03dcc66a8 --- /dev/null +++ b/base/database/database_test.go @@ -0,0 +1,303 @@ +package database + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "reflect" + "runtime/pprof" + "testing" + "time" + + q "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" + _ "github.com/safing/portmaster/base/database/storage/badger" + _ "github.com/safing/portmaster/base/database/storage/bbolt" + _ "github.com/safing/portmaster/base/database/storage/fstree" + _ "github.com/safing/portmaster/base/database/storage/hashmap" +) + +func TestMain(m *testing.M) { + testDir, err := os.MkdirTemp("", "portbase-database-testing-") + if err != nil { + panic(err) + } + + err = InitializeWithPath(testDir) + if err != nil { + panic(err) + } + + exitCode := m.Run() + + // Clean up the test directory. + // Do not defer, as we end this function with a os.Exit call. + _ = os.RemoveAll(testDir) + + os.Exit(exitCode) +} + +func makeKey(dbName, key string) string { + return fmt.Sprintf("%s:%s", dbName, key) +} + +func testDatabase(t *testing.T, storageType string, shadowDelete bool) { //nolint:maintidx,thelper + t.Run(fmt.Sprintf("TestStorage_%s_%v", storageType, shadowDelete), func(t *testing.T) { + dbName := fmt.Sprintf("testing-%s-%v", storageType, shadowDelete) + fmt.Println(dbName) + _, err := Register(&Database{ + Name: dbName, + Description: fmt.Sprintf("Unit Test Database for %s", storageType), + StorageType: storageType, + ShadowDelete: shadowDelete, + }) + if err != nil { + t.Fatal(err) + } + dbController, err := getController(dbName) + if err != nil { + t.Fatal(err) + } + + // hook + hook, err := RegisterHook(q.New(dbName).MustBeValid(), &HookBase{}) + if err != nil { + t.Fatal(err) + } + + // interface + db := NewInterface(&Options{ + Local: true, + Internal: true, + }) + + // sub + sub, err := db.Subscribe(q.New(dbName).MustBeValid()) + if err != nil { + t.Fatal(err) + } + + A := NewExample(dbName+":A", "Herbert", 411) + err = A.Save() + if err != nil { + t.Fatal(err) + } + + B := NewExample(makeKey(dbName, "B"), "Fritz", 347) + err = B.Save() + if err != nil { + t.Fatal(err) + } + + C := NewExample(makeKey(dbName, "C"), "Norbert", 217) + err = C.Save() + if err != nil { + t.Fatal(err) + } + + exists, err := db.Exists(makeKey(dbName, "A")) + if err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("record %s should exist!", makeKey(dbName, "A")) + } + + A1, err := GetExample(makeKey(dbName, "A")) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(A, A1) { + log.Fatalf("A and A1 mismatch, A1: %v", A1) + } + + cnt := countRecords(t, db, q.New(dbName).Where( + q.And( + q.Where("Name", q.EndsWith, "bert"), + q.Where("Score", q.GreaterThan, 100), + ), + )) + if cnt != 2 { + t.Fatalf("expected two records, got %d", cnt) + } + + // test putmany + if _, ok := dbController.storage.(storage.Batcher); ok { + batchPut := db.PutMany(dbName) + records := []record.Record{A, B, C, nil} // nil is to signify finish + for _, r := range records { + err = batchPut(r) + if err != nil { + t.Fatal(err) + } + } + } + + // test maintenance + if _, ok := dbController.storage.(storage.Maintainer); ok { + now := time.Now().UTC() + nowUnix := now.Unix() + + // we start with 3 records without expiry + cnt := 
countRecords(t, db, q.New(dbName)) + if cnt != 3 { + t.Fatalf("expected three records, got %d", cnt) + } + // delete entry + A.Meta().Deleted = nowUnix - 61 + err = A.Save() + if err != nil { + t.Fatal(err) + } + // expire entry + B.Meta().Expires = nowUnix - 1 + err = B.Save() + if err != nil { + t.Fatal(err) + } + + // one left + cnt = countRecords(t, db, q.New(dbName)) + if cnt != 1 { + t.Fatalf("expected one record, got %d", cnt) + } + + // run maintenance + err = dbController.MaintainRecordStates(context.TODO(), now.Add(-60*time.Second)) + if err != nil { + t.Fatal(err) + } + // one left + cnt = countRecords(t, db, q.New(dbName)) + if cnt != 1 { + t.Fatalf("expected one record, got %d", cnt) + } + + // check status individually + _, err = dbController.storage.Get("A") + if !errors.Is(err, storage.ErrNotFound) { + t.Errorf("A should be deleted and purged, err=%s", err) + } + B1, err := dbController.storage.Get("B") + if err != nil { + t.Fatalf("should exist: %s, original meta: %+v", err, B.Meta()) + } + if B1.Meta().Deleted == 0 { + t.Errorf("B should be deleted") + } + + // delete last entry + C.Meta().Deleted = nowUnix - 1 + err = C.Save() + if err != nil { + t.Fatal(err) + } + + // run maintenance + err = dbController.MaintainRecordStates(context.TODO(), now) + if err != nil { + t.Fatal(err) + } + + // check status individually + B2, err := dbController.storage.Get("B") + if err == nil { + t.Errorf("B should be deleted and purged, meta: %+v", B2.Meta()) + } else if !errors.Is(err, storage.ErrNotFound) { + t.Errorf("B should be deleted and purged, err=%s", err) + } + C2, err := dbController.storage.Get("C") + if err == nil { + t.Errorf("C should be deleted and purged, meta: %+v", C2.Meta()) + } else if !errors.Is(err, storage.ErrNotFound) { + t.Errorf("C should be deleted and purged, err=%s", err) + } + + // none left + cnt = countRecords(t, db, q.New(dbName)) + if cnt != 0 { + t.Fatalf("expected no records, got %d", cnt) + } + } + + err = hook.Cancel() + if err != nil { + t.Fatal(err) + } + err = sub.Cancel() + if err != nil { + t.Fatal(err) + } + }) +} + +func TestDatabaseSystem(t *testing.T) { //nolint:tparallel + t.Parallel() + + // panic after 10 seconds, to check for locks + finished := make(chan struct{}) + defer close(finished) + go func() { + select { + case <-finished: + case <-time.After(10 * time.Second): + fmt.Println("===== TAKING TOO LONG - PRINTING STACK TRACES =====") + _ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + os.Exit(1) + } + }() + + for _, shadowDelete := range []bool{false, true} { + testDatabase(t, "bbolt", shadowDelete) + testDatabase(t, "hashmap", shadowDelete) + testDatabase(t, "fstree", shadowDelete) + // testDatabase(t, "badger", shadowDelete) + // TODO: Fix badger tests + } + + err := MaintainRecordStates(context.TODO()) + if err != nil { + t.Fatal(err) + } + + err = Maintain(context.TODO()) + if err != nil { + t.Fatal(err) + } + + err = MaintainThorough(context.TODO()) + if err != nil { + t.Fatal(err) + } + + err = Shutdown() + if err != nil { + t.Fatal(err) + } +} + +func countRecords(t *testing.T, db *Interface, query *q.Query) int { + t.Helper() + + _, err := query.Check() + if err != nil { + t.Fatal(err) + } + + it, err := db.Query(query) + if err != nil { + t.Fatal(err) + } + + cnt := 0 + for range it.Next { + cnt++ + } + if it.Err() != nil { + t.Fatal(it.Err()) + } + return cnt +} diff --git a/base/database/dbmodule/db.go b/base/database/dbmodule/db.go new file mode 100644 index 000000000..23eecf1f8 --- /dev/null +++ 
b/base/database/dbmodule/db.go @@ -0,0 +1,50 @@ +package dbmodule + +import ( + "errors" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/utils" +) + +var ( + databaseStructureRoot *utils.DirStructure + + module *modules.Module +) + +func init() { + module = modules.Register("database", prep, start, stop) +} + +// SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure. +func SetDatabaseLocation(dirStructureRoot *utils.DirStructure) { + if databaseStructureRoot == nil { + databaseStructureRoot = dirStructureRoot + } +} + +func prep() error { + SetDatabaseLocation(dataroot.Root()) + if databaseStructureRoot == nil { + return errors.New("database location not specified") + } + + return nil +} + +func start() error { + err := database.Initialize(databaseStructureRoot) + if err != nil { + return err + } + + startMaintenanceTasks() + return nil +} + +func stop() error { + return database.Shutdown() +} diff --git a/base/database/dbmodule/maintenance.go b/base/database/dbmodule/maintenance.go new file mode 100644 index 000000000..3326ebaf4 --- /dev/null +++ b/base/database/dbmodule/maintenance.go @@ -0,0 +1,31 @@ +package dbmodule + +import ( + "context" + "time" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" +) + +func startMaintenanceTasks() { + module.NewTask("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute) + module.NewTask("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) + module.NewTask("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) +} + +func maintainBasic(ctx context.Context, task *modules.Task) error { + log.Infof("database: running Maintain") + return database.Maintain(ctx) +} + +func maintainThorough(ctx context.Context, task *modules.Task) error { + log.Infof("database: running MaintainThorough") + return database.MaintainThorough(ctx) +} + +func maintainRecords(ctx context.Context, task *modules.Task) error { + log.Infof("database: running MaintainRecordStates") + return database.MaintainRecordStates(ctx) +} diff --git a/base/database/doc.go b/base/database/doc.go new file mode 100644 index 000000000..1e1e6a5ff --- /dev/null +++ b/base/database/doc.go @@ -0,0 +1,62 @@ +/* +Package database provides a universal interface for interacting with the database. + +# A Lazy Database + +The database system can handle Go structs as well as serialized data by the dsd package. +While data is in transit within the system, it does not know which form it currently has. Only when it reaches its destination, it must ensure that it is either of a certain type or dump it. + +# Record Interface + +The database system uses the Record interface to transparently handle all types of structs that get saved in the database. Structs include the Base struct to fulfill most parts of the Record interface. + +Boilerplate Code: + + type Example struct { + record.Base + sync.Mutex + + Name string + Score int + } + + var ( + db = database.NewInterface(nil) + ) + + // GetExample gets an Example from the database. 
+ func GetExample(key string) (*Example, error) { + r, err := db.Get(key) + if err != nil { + return nil, err + } + + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + new := &Example{} + err = record.Unwrap(r, new) + if err != nil { + return nil, err + } + return new, nil + } + + // or adjust type + new, ok := r.(*Example) + if !ok { + return nil, fmt.Errorf("record not of type *Example, but %T", r) + } + return new, nil + } + + func (e *Example) Save() error { + return db.Put(e) + } + + func (e *Example) SaveAs(key string) error { + e.SetKey(key) + return db.PutNew(e) + } +*/ +package database diff --git a/base/database/errors.go b/base/database/errors.go new file mode 100644 index 000000000..425b0f319 --- /dev/null +++ b/base/database/errors.go @@ -0,0 +1,14 @@ +package database + +import ( + "errors" +) + +// Errors. +var ( + ErrNotFound = errors.New("database entry not found") + ErrPermissionDenied = errors.New("access to database record denied") + ErrReadOnly = errors.New("database is read only") + ErrShuttingDown = errors.New("database system is shutting down") + ErrNotImplemented = errors.New("not implemented by this storage") +) diff --git a/base/database/hook.go b/base/database/hook.go new file mode 100644 index 000000000..2115c94c6 --- /dev/null +++ b/base/database/hook.go @@ -0,0 +1,91 @@ +package database + +import ( + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" +) + +// Hook can be registered for a database query and +// will be executed at certain points during the life +// cycle of a database record. +type Hook interface { + // UsesPreGet should return true if the hook's PreGet + // should be called prior to loading a database record + // from the underlying storage. + UsesPreGet() bool + // PreGet is called before a database record is loaded from + // the underlying storage. A PreGet hookd may be used to + // implement more advanced access control on database keys. + PreGet(dbKey string) error + // UsesPostGet should return true if the hook's PostGet + // should be called after loading a database record from + // the underlying storage. + UsesPostGet() bool + // PostGet is called after a record has been loaded form the + // underlying storage and may perform additional mutation + // or access check based on the records data. + // The passed record is already locked by the database system + // so users can safely access all data of r. + PostGet(r record.Record) (record.Record, error) + // UsesPrePut should return true if the hook's PrePut method + // should be called prior to saving a record in the database. + UsesPrePut() bool + // PrePut is called prior to saving (creating or updating) a + // record in the database storage. It may be used to perform + // extended validation or mutations on the record. + // The passed record is already locked by the database system + // so users can safely access all data of r. + PrePut(r record.Record) (record.Record, error) +} + +// RegisteredHook is a registered database hook. +type RegisteredHook struct { + q *query.Query + h Hook +} + +// RegisterHook registers a hook for records matching the given +// query in the database. 
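Editor's note: the Hook interface above only requires real implementations for the hook points that are actually used; HookBase (defined further below) supplies no-op defaults for the rest. A minimal sketch of an access-control hook that blocks reads of keys under a certain prefix, assuming a hypothetical "config" database:

```go
package main

import (
	"errors"
	"strings"

	"github.com/safing/portmaster/base/database"
	"github.com/safing/portmaster/base/database/query"
)

// secretsGuard denies loading any record whose key starts with "secrets/".
// It embeds HookBase so only the PreGet hook point needs to be implemented.
type secretsGuard struct {
	database.HookBase
}

func (g *secretsGuard) UsesPreGet() bool { return true }

func (g *secretsGuard) PreGet(dbKey string) error {
	if strings.HasPrefix(dbKey, "secrets/") {
		return errors.New("access to secrets/ records is denied by hook")
	}
	return nil
}

// registerGuard registers the hook for all records of the "config" database.
func registerGuard() (*database.RegisteredHook, error) {
	return database.RegisterHook(query.New("config:"), &secretsGuard{})
}
```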
+func RegisterHook(q *query.Query, hook Hook) (*RegisteredHook, error) { + _, err := q.Check() + if err != nil { + return nil, err + } + + c, err := getController(q.DatabaseName()) + if err != nil { + return nil, err + } + + rh := &RegisteredHook{ + q: q, + h: hook, + } + + c.hooksLock.Lock() + defer c.hooksLock.Unlock() + c.hooks = append(c.hooks, rh) + + return rh, nil +} + +// Cancel unregisteres the hook from the database. Once +// Cancel returned the hook's methods will not be called +// anymore for updates that matched the registered query. +func (h *RegisteredHook) Cancel() error { + c, err := getController(h.q.DatabaseName()) + if err != nil { + return err + } + + c.hooksLock.Lock() + defer c.hooksLock.Unlock() + + for key, hook := range c.hooks { + if hook.q == h.q { + c.hooks = append(c.hooks[:key], c.hooks[key+1:]...) + return nil + } + } + return nil +} diff --git a/base/database/hookbase.go b/base/database/hookbase.go new file mode 100644 index 000000000..5d11f34c5 --- /dev/null +++ b/base/database/hookbase.go @@ -0,0 +1,38 @@ +package database + +import ( + "github.com/safing/portmaster/base/database/record" +) + +// HookBase implements the Hook interface and provides dummy functions to reduce boilerplate. +type HookBase struct{} + +// UsesPreGet implements the Hook interface and returns false. +func (b *HookBase) UsesPreGet() bool { + return false +} + +// UsesPostGet implements the Hook interface and returns false. +func (b *HookBase) UsesPostGet() bool { + return false +} + +// UsesPrePut implements the Hook interface and returns false. +func (b *HookBase) UsesPrePut() bool { + return false +} + +// PreGet implements the Hook interface. +func (b *HookBase) PreGet(dbKey string) error { + return nil +} + +// PostGet implements the Hook interface. +func (b *HookBase) PostGet(r record.Record) (record.Record, error) { + return r, nil +} + +// PrePut implements the Hook interface. +func (b *HookBase) PrePut(r record.Record) (record.Record, error) { + return r, nil +} diff --git a/base/database/interface.go b/base/database/interface.go new file mode 100644 index 000000000..ce9b8a973 --- /dev/null +++ b/base/database/interface.go @@ -0,0 +1,585 @@ +package database + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/bluele/gcache" + "github.com/tevino/abool" + + "github.com/safing/portmaster/base/database/accessor" + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" +) + +const ( + getDBFromKey = "" +) + +// Interface provides a method to access the database with attached options. +type Interface struct { + options *Options + cache gcache.Cache + + writeCache map[string]record.Record + writeCacheLock sync.Mutex + triggerCacheWrite chan struct{} +} + +// Options holds options that may be set for an Interface instance. +type Options struct { + // Local specifies if the interface is used by an actor on the local device. + // Setting both the Local and Internal flags will bring performance + // improvements because less checks are needed. + Local bool + + // Internal specifies if the interface is used by an actor within the + // software. Setting both the Local and Internal flags will bring performance + // improvements because less checks are needed. + Internal bool + + // AlwaysMakeSecret will have the interface mark all saved records as secret. + // This means that they will be only accessible by an internal interface. 
+ AlwaysMakeSecret bool + + // AlwaysMakeCrownjewel will have the interface mark all saved records as + // crown jewels. This means that they will be only accessible by a local + // interface. + AlwaysMakeCrownjewel bool + + // AlwaysSetRelativateExpiry will have the interface set a relative expiry, + // based on the current time, on all saved records. + AlwaysSetRelativateExpiry int64 + + // AlwaysSetAbsoluteExpiry will have the interface set an absolute expiry on + // all saved records. + AlwaysSetAbsoluteExpiry int64 + + // CacheSize defines that a cache should be used for this interface and + // defines it's size. + // Caching comes with an important caveat: If database records are changed + // from another interface, the cache will not be invalidated for these + // records. It will therefore serve outdated data until that record is + // evicted from the cache. + CacheSize int + + // DelayCachedWrites defines a database name for which cache writes should + // be cached and batched. The database backend must support the Batcher + // interface. This option is only valid if used with a cache. + // Additionally, this may only be used for internal and local interfaces. + // Please note that this means that other interfaces will not be able to + // guarantee to serve the latest record if records are written this way. + DelayCachedWrites string +} + +// Apply applies options to the record metadata. +func (o *Options) Apply(r record.Record) { + r.UpdateMeta() + if o.AlwaysMakeSecret { + r.Meta().MakeSecret() + } + if o.AlwaysMakeCrownjewel { + r.Meta().MakeCrownJewel() + } + if o.AlwaysSetAbsoluteExpiry > 0 { + r.Meta().SetAbsoluteExpiry(o.AlwaysSetAbsoluteExpiry) + } else if o.AlwaysSetRelativateExpiry > 0 { + r.Meta().SetRelativateExpiry(o.AlwaysSetRelativateExpiry) + } +} + +// HasAllPermissions returns whether the options specify the highest possible +// permissions for operations. +func (o *Options) HasAllPermissions() bool { + return o.Local && o.Internal +} + +// hasAccessPermission checks if the interface options permit access to the +// given record, locking the record for accessing it's attributes. +func (o *Options) hasAccessPermission(r record.Record) bool { + // Check if the options specify all permissions, which makes checking the + // record unnecessary. + if o.HasAllPermissions() { + return true + } + + r.Lock() + defer r.Unlock() + + // Check permissions against record. + return r.Meta().CheckPermission(o.Local, o.Internal) +} + +// NewInterface returns a new Interface to the database. +func NewInterface(opts *Options) *Interface { + if opts == nil { + opts = &Options{} + } + + newIface := &Interface{ + options: opts, + } + if opts.CacheSize > 0 { + cacheBuilder := gcache.New(opts.CacheSize).ARC() + if opts.DelayCachedWrites != "" { + cacheBuilder.EvictedFunc(newIface.cacheEvictHandler) + newIface.writeCache = make(map[string]record.Record, opts.CacheSize/2) + newIface.triggerCacheWrite = make(chan struct{}) + } + newIface.cache = cacheBuilder.Build() + } + return newIface +} + +// Exists return whether a record with the given key exists. +func (i *Interface) Exists(key string) (bool, error) { + _, err := i.Get(key) + if err != nil { + switch { + case errors.Is(err, ErrNotFound): + return false, nil + case errors.Is(err, ErrPermissionDenied): + return true, nil + default: + return false, err + } + } + return true, nil +} + +// Get return the record with the given key. 
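Editor's note: the options above are fixed per interface instance. A short sketch contrasting a restricted, cached interface with a fully privileged internal one; the cache size and expiry value are illustrative only:

```go
package main

import "github.com/safing/portmaster/base/database"

// newInterfaces shows two typical option sets.
func newInterfaces() (ui, core *database.Interface) {
	// An interface for a local UI process: not internal, so secret records
	// stay hidden from it; a small read cache speeds up repeated lookups.
	ui = database.NewInterface(&database.Options{
		Local:     true,
		CacheSize: 128,
	})

	// A fully privileged internal interface that additionally stamps a
	// relative expiry (in seconds) on every record it saves.
	core = database.NewInterface(&database.Options{
		Local:                     true,
		Internal:                  true,
		AlwaysSetRelativateExpiry: 3600,
	})
	return ui, core
}
```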
+func (i *Interface) Get(key string) (record.Record, error) { + r, _, err := i.getRecord(getDBFromKey, key, false) + return r, err +} + +func (i *Interface) getRecord(dbName string, dbKey string, mustBeWriteable bool) (r record.Record, db *Controller, err error) { //nolint:unparam + if dbName == "" { + dbName, dbKey = record.ParseKey(dbKey) + } + + db, err = getController(dbName) + if err != nil { + return nil, nil, err + } + + if mustBeWriteable && db.ReadOnly() { + return nil, db, ErrReadOnly + } + + r = i.checkCache(dbName + ":" + dbKey) + if r != nil { + if !i.options.hasAccessPermission(r) { + return nil, db, ErrPermissionDenied + } + return r, db, nil + } + + r, err = db.Get(dbKey) + if err != nil { + return nil, db, err + } + + if !i.options.hasAccessPermission(r) { + return nil, db, ErrPermissionDenied + } + + r.Lock() + ttl := r.Meta().GetRelativeExpiry() + r.Unlock() + i.updateCache( + r, + false, // writing + false, // remove + ttl, // expiry + ) + + return r, db, nil +} + +func (i *Interface) getMeta(dbName string, dbKey string, mustBeWriteable bool) (m *record.Meta, db *Controller, err error) { //nolint:unparam + if dbName == "" { + dbName, dbKey = record.ParseKey(dbKey) + } + + db, err = getController(dbName) + if err != nil { + return nil, nil, err + } + + if mustBeWriteable && db.ReadOnly() { + return nil, db, ErrReadOnly + } + + r := i.checkCache(dbName + ":" + dbKey) + if r != nil { + if !i.options.hasAccessPermission(r) { + return nil, db, ErrPermissionDenied + } + return r.Meta(), db, nil + } + + m, err = db.GetMeta(dbKey) + if err != nil { + return nil, db, err + } + + if !m.CheckPermission(i.options.Local, i.options.Internal) { + return nil, db, ErrPermissionDenied + } + + return m, db, nil +} + +// InsertValue inserts a value into a record. +func (i *Interface) InsertValue(key string, attribute string, value interface{}) error { + r, db, err := i.getRecord(getDBFromKey, key, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + var acc accessor.Accessor + if r.IsWrapped() { + wrapper, ok := r.(*record.Wrapper) + if !ok { + return errors.New("record is malformed (reports to be wrapped but is not of type *record.Wrapper)") + } + acc = accessor.NewJSONBytesAccessor(&wrapper.Data) + } else { + acc = accessor.NewStructAccessor(r) + } + + err = acc.Set(attribute, value) + if err != nil { + return fmt.Errorf("failed to set value with %s: %w", acc.Type(), err) + } + + i.options.Apply(r) + return db.Put(r) +} + +// Put saves a record to the database. +func (i *Interface) Put(r record.Record) (err error) { + // get record or only database + var db *Controller + if !i.options.HasAllPermissions() { + _, db, err = i.getMeta(r.DatabaseName(), r.DatabaseKey(), true) + if err != nil && !errors.Is(err, ErrNotFound) { + return err + } + } else { + db, err = getController(r.DatabaseName()) + if err != nil { + return err + } + } + + // Check if database is read only. + if db.ReadOnly() { + return ErrReadOnly + } + + r.Lock() + i.options.Apply(r) + remove := r.Meta().IsDeleted() + ttl := r.Meta().GetRelativeExpiry() + r.Unlock() + + // The record may not be locked when updating the cache. + written := i.updateCache(r, true, remove, ttl) + if written { + return nil + } + + r.Lock() + defer r.Unlock() + return db.Put(r) +} + +// PutNew saves a record to the database as a new record (ie. with new timestamps). 
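Editor's note: InsertValue above modifies a single attribute in place, which saves the caller from loading, asserting and re-saving the record manually. A small sketch; the key and the "Enabled" attribute are hypothetical:

```go
package main

import "github.com/safing/portmaster/base/database"

// setEnabled flips one boolean attribute on an existing record.
func setEnabled(db *database.Interface, key string) error {
	exists, err := db.Exists(key)
	if err != nil {
		return err
	}
	if !exists {
		return database.ErrNotFound
	}

	// InsertValue loads the record, sets the attribute via the matching
	// accessor (JSON bytes or struct) and writes it back in one call.
	return db.InsertValue(key, "Enabled", true)
}
```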
+func (i *Interface) PutNew(r record.Record) (err error) { + // get record or only database + var db *Controller + if !i.options.HasAllPermissions() { + _, db, err = i.getMeta(r.DatabaseName(), r.DatabaseKey(), true) + if err != nil && !errors.Is(err, ErrNotFound) { + return err + } + } else { + db, err = getController(r.DatabaseName()) + if err != nil { + return err + } + } + + // Check if database is read only. + if db.ReadOnly() { + return ErrReadOnly + } + + r.Lock() + if r.Meta() != nil { + r.Meta().Reset() + } + i.options.Apply(r) + remove := r.Meta().IsDeleted() + ttl := r.Meta().GetRelativeExpiry() + r.Unlock() + + // The record may not be locked when updating the cache. + written := i.updateCache(r, true, remove, ttl) + if written { + return nil + } + + r.Lock() + defer r.Unlock() + return db.Put(r) +} + +// PutMany stores many records in the database. +// Warning: This is nearly a direct database access and omits many things: +// - Record locking +// - Hooks +// - Subscriptions +// - Caching +// Use with care. +func (i *Interface) PutMany(dbName string) (put func(record.Record) error) { + interfaceBatch := make(chan record.Record, 100) + + // permission check + if !i.options.HasAllPermissions() { + return func(r record.Record) error { + return ErrPermissionDenied + } + } + + // get database + db, err := getController(dbName) + if err != nil { + return func(r record.Record) error { + return err + } + } + + // Check if database is read only. + if db.ReadOnly() { + return func(r record.Record) error { + return ErrReadOnly + } + } + + // start database access + dbBatch, errs := db.PutMany() + finished := abool.New() + var internalErr error + + // interface options proxy + go func() { + defer close(dbBatch) // signify that we are finished + for { + select { + case r := <-interfaceBatch: + // finished? + if r == nil { + return + } + // apply options + i.options.Apply(r) + // pass along + dbBatch <- r + case <-time.After(1 * time.Second): + // bail out + internalErr = errors.New("timeout: putmany unused for too long") + finished.Set() + return + } + } + }() + + return func(r record.Record) error { + // finished? + if finished.IsSet() { + // check for internal error + if internalErr != nil { + return internalErr + } + // check for previous error + select { + case err := <-errs: + return err + default: + return errors.New("batch is closed") + } + } + + // finish? + if r == nil { + finished.Set() + interfaceBatch <- nil // signify that we are finished + // do not close, as this fn could be called again with nil. + return <-errs + } + + // check record scope + if r.DatabaseName() != dbName { + return errors.New("record out of database scope") + } + + // submit + select { + case interfaceBatch <- r: + return nil + case err := <-errs: + return err + } + } +} + +// SetAbsoluteExpiry sets an absolute record expiry. +func (i *Interface) SetAbsoluteExpiry(key string, time int64) error { + r, db, err := i.getRecord(getDBFromKey, key, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().SetAbsoluteExpiry(time) + return db.Put(r) +} + +// SetRelativateExpiry sets a relative (self-updating) record expiry. 
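Editor's note: PutMany above returns a put function that must be called once with nil to finish the batch, at which point the batch error is reported. A usage sketch mirroring how the tests drive it:

```go
package main

import (
	"github.com/safing/portmaster/base/database"
	"github.com/safing/portmaster/base/database/record"
)

// saveBatch writes a batch of records to the named database, bypassing
// hooks, subscriptions and caching as documented above.
func saveBatch(db *database.Interface, dbName string, records []record.Record) error {
	batchPut := db.PutMany(dbName)

	for _, r := range records {
		if err := batchPut(r); err != nil {
			return err
		}
	}

	// Passing nil signals the end of the batch and returns its final error.
	return batchPut(nil)
}
```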
+func (i *Interface) SetRelativateExpiry(key string, duration int64) error { + r, db, err := i.getRecord(getDBFromKey, key, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().SetRelativateExpiry(duration) + return db.Put(r) +} + +// MakeSecret marks the record as a secret, meaning interfacing processes, such as an UI, are denied access to the record. +func (i *Interface) MakeSecret(key string) error { + r, db, err := i.getRecord(getDBFromKey, key, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().MakeSecret() + return db.Put(r) +} + +// MakeCrownJewel marks a record as a crown jewel, meaning it will only be accessible locally. +func (i *Interface) MakeCrownJewel(key string) error { + r, db, err := i.getRecord(getDBFromKey, key, true) + if err != nil { + return err + } + + r.Lock() + defer r.Unlock() + + i.options.Apply(r) + r.Meta().MakeCrownJewel() + return db.Put(r) +} + +// Delete deletes a record from the database. +func (i *Interface) Delete(key string) error { + r, db, err := i.getRecord(getDBFromKey, key, true) + if err != nil { + return err + } + + // Check if database is read only. + if db.ReadOnly() { + return ErrReadOnly + } + + i.options.Apply(r) + r.Meta().Delete() + return db.Put(r) +} + +// Query executes the given query on the database. +// Will not see data that is in the write cache, waiting to be written. +// Use with care with caching. +func (i *Interface) Query(q *query.Query) (*iterator.Iterator, error) { + _, err := q.Check() + if err != nil { + return nil, err + } + + db, err := getController(q.DatabaseName()) + if err != nil { + return nil, err + } + + // TODO: Finish caching system integration. + // Flush the cache before we query the database. + // i.FlushCache() + + return db.Query(q, i.options.Local, i.options.Internal) +} + +// Purge deletes all records that match the given query. It returns the number +// of successful deletes and an error. +func (i *Interface) Purge(ctx context.Context, q *query.Query) (int, error) { + _, err := q.Check() + if err != nil { + return 0, err + } + + db, err := getController(q.DatabaseName()) + if err != nil { + return 0, err + } + + // Check if database is read only before we add to the cache. + if db.ReadOnly() { + return 0, ErrReadOnly + } + + return db.Purge(ctx, q, i.options.Local, i.options.Internal) +} + +// Subscribe subscribes to updates matching the given query. +func (i *Interface) Subscribe(q *query.Query) (*Subscription, error) { + _, err := q.Check() + if err != nil { + return nil, err + } + + c, err := getController(q.DatabaseName()) + if err != nil { + return nil, err + } + + sub := &Subscription{ + q: q, + local: i.options.Local, + internal: i.options.Internal, + Feed: make(chan record.Record, 1000), + } + c.addSubscription(sub) + return sub, nil +} diff --git a/base/database/interface_cache.go b/base/database/interface_cache.go new file mode 100644 index 000000000..06b16d174 --- /dev/null +++ b/base/database/interface_cache.go @@ -0,0 +1,227 @@ +package database + +import ( + "context" + "errors" + "time" + + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" +) + +// DelayedCacheWriter must be run by the caller of an interface that uses delayed cache writing. +func (i *Interface) DelayedCacheWriter(ctx context.Context) error { + // Check if the DelayedCacheWriter should be run at all. 
+ if i.options.CacheSize <= 0 || i.options.DelayCachedWrites == "" { + return errors.New("delayed cache writer is not applicable to this database interface") + } + + // Check if backend support the Batcher interface. + batchPut := i.PutMany(i.options.DelayCachedWrites) + // End batchPut immediately and check for an error. + err := batchPut(nil) + if err != nil { + return err + } + + // percentThreshold defines the minimum percentage of entries in the write cache in relation to the cache size that need to be present in order for flushing the cache to the database storage. + percentThreshold := 25 + thresholdWriteTicker := time.NewTicker(5 * time.Second) + forceWriteTicker := time.NewTicker(5 * time.Minute) + + for { + // Wait for trigger for writing the cache. + select { + case <-ctx.Done(): + // The caller is shutting down, flush the cache to storage and exit. + i.flushWriteCache(0) + return nil + + case <-i.triggerCacheWrite: + // An entry from the cache was evicted that was also in the write cache. + // This makes it likely that other entries that are also present in the + // write cache will be evicted soon. Flush the write cache to storage + // immediately in order to reduce single writes. + i.flushWriteCache(0) + + case <-thresholdWriteTicker.C: + // Often check if the write cache has filled up to a certain degree and + // flush it to storage before we start evicting to-be-written entries and + // slow down the hot path again. + i.flushWriteCache(percentThreshold) + + case <-forceWriteTicker.C: + // Once in a while, flush the write cache to storage no matter how much + // it is filled. We don't want entries lingering around in the write + // cache forever. This also reduces the amount of data loss in the event + // of a total crash. + i.flushWriteCache(0) + } + } +} + +// ClearCache clears the read cache. +func (i *Interface) ClearCache() { + // Check if cache is in use. + if i.cache == nil { + return + } + + // Clear all cache entries. + i.cache.Purge() +} + +// FlushCache writes (and thus clears) the write cache. +func (i *Interface) FlushCache() { + // Check if write cache is in use. + if i.options.DelayCachedWrites != "" { + return + } + + i.flushWriteCache(0) +} + +func (i *Interface) flushWriteCache(percentThreshold int) { + i.writeCacheLock.Lock() + defer i.writeCacheLock.Unlock() + + // Check if there is anything to do. + if len(i.writeCache) == 0 { + return + } + + // Check if we reach the given threshold for writing to storage. + if (len(i.writeCache)*100)/i.options.CacheSize < percentThreshold { + return + } + + // Write the full cache in a batch operation. + batchPut := i.PutMany(i.options.DelayCachedWrites) + for _, r := range i.writeCache { + err := batchPut(r) + if err != nil { + log.Warningf("database: failed to write write-cached entry to %q database: %s", i.options.DelayCachedWrites, err) + } + } + // Finish batch. + err := batchPut(nil) + if err != nil { + log.Warningf("database: failed to finish flushing write cache to %q database: %s", i.options.DelayCachedWrites, err) + } + + // Optimized map clearing following the Go1.11 recommendation. + for key := range i.writeCache { + delete(i.writeCache, key) + } +} + +// cacheEvictHandler is run by the cache for every entry that gets evicted +// from the cache. +func (i *Interface) cacheEvictHandler(keyData, _ interface{}) { + // Transform the key into a string. + key, ok := keyData.(string) + if !ok { + return + } + + // Check if the evicted record is one that is to be written. 
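Editor's note: DelayedCacheWriter above has to be kept running for as long as an interface with DelayCachedWrites is in use; the benchmarks further below do this inline. A minimal sketch of the same pattern outside of tests; sizes and names are illustrative:

```go
package main

import (
	"context"
	"sync"

	"github.com/safing/portmaster/base/database"
)

// newCachedInterface creates an interface with a write-delaying cache for the
// given database and starts its DelayedCacheWriter. The returned stop function
// cancels the writer and waits for its final flush.
func newCachedInterface(dbName string) (*database.Interface, func()) {
	db := database.NewInterface(&database.Options{
		Local:             true,
		Internal:          true,
		CacheSize:         256,
		DelayCachedWrites: dbName,
	})

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// Flushes the write cache periodically and once more on cancellation.
		_ = db.DelayedCacheWriter(ctx)
	}()

	stop := func() {
		cancel()
		wg.Wait()
	}
	return db, stop
}
```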
+ // Lock the write cache until the end of the function. + // The read cache is locked anyway for the whole duration. + i.writeCacheLock.Lock() + defer i.writeCacheLock.Unlock() + r, ok := i.writeCache[key] + if ok { + delete(i.writeCache, key) + } + if !ok { + return + } + + // Write record to database in order to mitigate race conditions where the record would appear + // as non-existent for a short duration. + db, err := getController(r.DatabaseName()) + if err != nil { + log.Warningf("database: failed to write evicted cache entry %q: database %q does not exist", key, r.DatabaseName()) + return + } + + r.Lock() + defer r.Unlock() + + err = db.Put(r) + if err != nil { + log.Warningf("database: failed to write evicted cache entry %q to database: %s", key, err) + } + + // Finally, trigger writing the full write cache because a to-be-written + // entry was just evicted from the cache, and this makes it likely that more + // to-be-written entries will be evicted shortly. + select { + case i.triggerCacheWrite <- struct{}{}: + default: + } +} + +func (i *Interface) checkCache(key string) record.Record { + // Check if cache is in use. + if i.cache == nil { + return nil + } + + // Check if record exists in cache. + cacheVal, err := i.cache.Get(key) + if err == nil { + r, ok := cacheVal.(record.Record) + if ok { + return r + } + } + return nil +} + +// updateCache updates an entry in the interface cache. The given record may +// not be locked, as updating the cache might write an (unrelated) evicted +// record to the database in the process. If this happens while the +// DelayedCacheWriter flushes the write cache with the same record present, +// this will deadlock. +func (i *Interface) updateCache(r record.Record, write bool, remove bool, ttl int64) (written bool) { + // Check if cache is in use. + if i.cache == nil { + return false + } + + // Check if record should be deleted + if remove { + // Remove entry from cache. + i.cache.Remove(r.Key()) + // Let write through to database storage. + return false + } + + // Update cache with record. + if ttl >= 0 { + _ = i.cache.SetWithExpire( + r.Key(), + r, + time.Duration(ttl)*time.Second, + ) + } else { + _ = i.cache.Set( + r.Key(), + r, + ) + } + + // Add record to write cache instead if: + // 1. The record is being written. + // 2. Write delaying is active. + // 3. Write delaying is active for the database of this record. + if write && r.DatabaseName() == i.options.DelayCachedWrites { + i.writeCacheLock.Lock() + defer i.writeCacheLock.Unlock() + i.writeCache[r.Key()] = r + return true + } + + return false +} diff --git a/base/database/interface_cache_test.go b/base/database/interface_cache_test.go new file mode 100644 index 000000000..cfed4388c --- /dev/null +++ b/base/database/interface_cache_test.go @@ -0,0 +1,156 @@ +package database + +import ( + "context" + "fmt" + "strconv" + "sync" + "testing" +) + +func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper + b.Run(fmt.Sprintf("CacheWriting_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) { + // Setup Benchmark. + + // Create database. + dbName := fmt.Sprintf("cache-w-benchmark-%s-%d-%d-%v", storageType, cacheSize, sampleSize, delayWrites) + _, err := Register(&Database{ + Name: dbName, + Description: fmt.Sprintf("Cache Benchmark Database for %s", storageType), + StorageType: storageType, + }) + if err != nil { + b.Fatal(err) + } + + // Create benchmark interface. 
+ options := &Options{ + Local: true, + Internal: true, + CacheSize: cacheSize, + } + if cacheSize > 0 && delayWrites { + options.DelayCachedWrites = dbName + } + db := NewInterface(options) + + // Start + ctx, cancelCtx := context.WithCancel(context.Background()) + var wg sync.WaitGroup + if cacheSize > 0 && delayWrites { + wg.Add(1) + go func() { + err := db.DelayedCacheWriter(ctx) + if err != nil { + panic(err) + } + wg.Done() + }() + } + + // Start Benchmark. + b.ResetTimer() + for i := 0; i < b.N; i++ { + testRecordID := i % sampleSize + r := NewExample( + dbName+":"+strconv.Itoa(testRecordID), + "A", + 1, + ) + err = db.Put(r) + if err != nil { + b.Fatal(err) + } + } + + // End cache writer and wait + cancelCtx() + wg.Wait() + }) +} + +func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper + b.Run(fmt.Sprintf("CacheReadWrite_%s_%d_%d_%v", storageType, cacheSize, sampleSize, delayWrites), func(b *testing.B) { + // Setup Benchmark. + + // Create database. + dbName := fmt.Sprintf("cache-rw-benchmark-%s-%d-%d-%v", storageType, cacheSize, sampleSize, delayWrites) + _, err := Register(&Database{ + Name: dbName, + Description: fmt.Sprintf("Cache Benchmark Database for %s", storageType), + StorageType: storageType, + }) + if err != nil { + b.Fatal(err) + } + + // Create benchmark interface. + options := &Options{ + Local: true, + Internal: true, + CacheSize: cacheSize, + } + if cacheSize > 0 && delayWrites { + options.DelayCachedWrites = dbName + } + db := NewInterface(options) + + // Start + ctx, cancelCtx := context.WithCancel(context.Background()) + var wg sync.WaitGroup + if cacheSize > 0 && delayWrites { + wg.Add(1) + go func() { + err := db.DelayedCacheWriter(ctx) + if err != nil { + panic(err) + } + wg.Done() + }() + } + + // Start Benchmark. + b.ResetTimer() + writing := true + for i := 0; i < b.N; i++ { + testRecordID := i % sampleSize + key := dbName + ":" + strconv.Itoa(testRecordID) + + if i > 0 && testRecordID == 0 { + writing = !writing // switch between reading and writing every samplesize + } + + if writing { + r := NewExample(key, "A", 1) + err = db.Put(r) + } else { + _, err = db.Get(key) + } + if err != nil { + b.Fatal(err) + } + } + + // End cache writer and wait + cancelCtx() + wg.Wait() + }) +} + +func BenchmarkCache(b *testing.B) { + for _, storageType := range []string{"bbolt", "hashmap"} { + benchmarkCacheWriting(b, storageType, 32, 8, false) + benchmarkCacheWriting(b, storageType, 32, 8, true) + benchmarkCacheWriting(b, storageType, 32, 1024, false) + benchmarkCacheWriting(b, storageType, 32, 1024, true) + benchmarkCacheWriting(b, storageType, 512, 1024, false) + benchmarkCacheWriting(b, storageType, 512, 1024, true) + + benchmarkCacheReadWrite(b, storageType, 32, 8, false) + benchmarkCacheReadWrite(b, storageType, 32, 8, true) + benchmarkCacheReadWrite(b, storageType, 32, 1024, false) + benchmarkCacheReadWrite(b, storageType, 32, 1024, true) + benchmarkCacheReadWrite(b, storageType, 512, 1024, false) + benchmarkCacheReadWrite(b, storageType, 512, 1024, true) + } +} diff --git a/base/database/iterator/iterator.go b/base/database/iterator/iterator.go new file mode 100644 index 000000000..b6dc86a1d --- /dev/null +++ b/base/database/iterator/iterator.go @@ -0,0 +1,54 @@ +package iterator + +import ( + "sync" + + "github.com/tevino/abool" + + "github.com/safing/portmaster/base/database/record" +) + +// Iterator defines the iterator structure. 
+type Iterator struct { + Next chan record.Record + Done chan struct{} + + errLock sync.Mutex + err error + doneClosed *abool.AtomicBool +} + +// New creates a new Iterator. +func New() *Iterator { + return &Iterator{ + Next: make(chan record.Record, 10), + Done: make(chan struct{}), + doneClosed: abool.NewBool(false), + } +} + +// Finish is called be the storage to signal the end of the query results. +func (it *Iterator) Finish(err error) { + close(it.Next) + if it.doneClosed.SetToIf(false, true) { + close(it.Done) + } + + it.errLock.Lock() + defer it.errLock.Unlock() + it.err = err +} + +// Cancel is called by the iteration consumer to cancel the running query. +func (it *Iterator) Cancel() { + if it.doneClosed.SetToIf(false, true) { + close(it.Done) + } +} + +// Err returns the iterator error, if exists. +func (it *Iterator) Err() error { + it.errLock.Lock() + defer it.errLock.Unlock() + return it.err +} diff --git a/base/database/main.go b/base/database/main.go new file mode 100644 index 000000000..9a03420db --- /dev/null +++ b/base/database/main.go @@ -0,0 +1,85 @@ +package database + +import ( + "errors" + "fmt" + "path/filepath" + + "github.com/tevino/abool" + + "github.com/safing/portmaster/base/utils" +) + +const ( + databasesSubDir = "databases" +) + +var ( + initialized = abool.NewBool(false) + + shuttingDown = abool.NewBool(false) + shutdownSignal = make(chan struct{}) + + rootStructure *utils.DirStructure + databasesStructure *utils.DirStructure +) + +// InitializeWithPath initializes the database at the specified location using a path. +func InitializeWithPath(dirPath string) error { + return Initialize(utils.NewDirStructure(dirPath, 0o0755)) +} + +// Initialize initializes the database at the specified location using a dir structure. +func Initialize(dirStructureRoot *utils.DirStructure) error { + if initialized.SetToIf(false, true) { + rootStructure = dirStructureRoot + + // ensure root and databases dirs + databasesStructure = rootStructure.ChildDir(databasesSubDir, 0o0700) + err := databasesStructure.Ensure() + if err != nil { + return fmt.Errorf("could not create/open database directory (%s): %w", rootStructure.Path, err) + } + + if registryPersistence.IsSet() { + err = loadRegistry() + if err != nil { + return fmt.Errorf("could not load database registry (%s): %w", filepath.Join(rootStructure.Path, registryFileName), err) + } + } + + return nil + } + return errors.New("database already initialized") +} + +// Shutdown shuts down the whole database system. +func Shutdown() (err error) { + if shuttingDown.SetToIf(false, true) { + close(shutdownSignal) + } else { + return + } + + controllersLock.RLock() + defer controllersLock.RUnlock() + + for _, c := range controllers { + err = c.Shutdown() + if err != nil { + return + } + } + return +} + +// getLocation returns the storage location for the given name and type. +func getLocation(name, storageType string) (string, error) { + location := databasesStructure.ChildDir(name, 0o0700).ChildDir(storageType, 0o0700) + // check location + err := location.Ensure() + if err != nil { + return "", fmt.Errorf(`failed to create/check database dir "%s": %w`, location.Path, err) + } + return location.Path, nil +} diff --git a/base/database/maintenance.go b/base/database/maintenance.go new file mode 100644 index 000000000..19c484fac --- /dev/null +++ b/base/database/maintenance.go @@ -0,0 +1,64 @@ +package database + +import ( + "context" + "time" +) + +// Maintain runs the Maintain method on all storages. 
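Editor's note: iterators returned by Query are consumed by ranging over Next and checking Err once the channel is drained, as countRecords in the tests does. A compact sketch; the database prefix and condition are illustrative:

```go
package main

import (
	"github.com/safing/portmaster/base/database"
	"github.com/safing/portmaster/base/database/query"
	"github.com/safing/portmaster/base/database/record"
)

// highScores returns all records of a (hypothetical) "scores" database
// with a Score greater than 100.
func highScores(db *database.Interface) ([]record.Record, error) {
	q := query.New("scores:").Where(query.Where("Score", query.GreaterThan, 100))

	it, err := db.Query(q)
	if err != nil {
		return nil, err
	}

	var results []record.Record
	for r := range it.Next {
		results = append(results, r)
	}
	// Check the iterator error only after Next has been drained.
	return results, it.Err()
}
```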
+func Maintain(ctx context.Context) (err error) { + // copy, as we might use the very long + all := duplicateControllers() + + for _, c := range all { + err = c.Maintain(ctx) + if err != nil { + return + } + } + return +} + +// MaintainThorough runs the MaintainThorough method on all storages. +func MaintainThorough(ctx context.Context) (err error) { + // copy, as we might use the very long + all := duplicateControllers() + + for _, c := range all { + err = c.MaintainThorough(ctx) + if err != nil { + return + } + } + return +} + +// MaintainRecordStates runs record state lifecycle maintenance on all storages. +func MaintainRecordStates(ctx context.Context) (err error) { + // delete immediately for now + // TODO: increase purge threshold when starting to sync DBs + purgeDeletedBefore := time.Now().UTC() + + // copy, as we might use the very long + all := duplicateControllers() + + for _, c := range all { + err = c.MaintainRecordStates(ctx, purgeDeletedBefore) + if err != nil { + return + } + } + return +} + +func duplicateControllers() (all []*Controller) { + controllersLock.RLock() + defer controllersLock.RUnlock() + + all = make([]*Controller, 0, len(controllers)) + for _, c := range controllers { + all = append(all, c) + } + + return +} diff --git a/base/database/migration/error.go b/base/database/migration/error.go new file mode 100644 index 000000000..1ecb99f66 --- /dev/null +++ b/base/database/migration/error.go @@ -0,0 +1,58 @@ +package migration + +import "errors" + +// DiagnosticStep describes one migration step in the Diagnostics. +type DiagnosticStep struct { + Version string + Description string +} + +// Diagnostics holds a detailed error report about a failed migration. +type Diagnostics struct { //nolint:errname + // Message holds a human readable message of the encountered + // error. + Message string + // Wrapped must be set to the underlying error that was encountered + // while preparing or executing migrations. + Wrapped error + // StartOfMigration is set to the version of the database before + // any migrations are applied. + StartOfMigration string + // LastSuccessfulMigration is set to the version of the database + // which has been applied successfully before the error happened. + LastSuccessfulMigration string + // TargetVersion is set to the version of the database that the + // migration run aimed for. That is, it's the last available version + // added to the registry. + TargetVersion string + // ExecutionPlan is a list of migration steps that were planned to + // be executed. + ExecutionPlan []DiagnosticStep + // FailedMigration is the description of the migration that has + // failed. + FailedMigration string +} + +// Error returns a string representation of the migration error. +func (err *Diagnostics) Error() string { + msg := "" + if err.FailedMigration != "" { + msg = err.FailedMigration + ": " + } + if err.Message != "" { + msg += err.Message + ": " + } + msg += err.Wrapped.Error() + return msg +} + +// Unwrap returns the actual error that happened when executing +// a migration. It implements the interface required by the stdlib +// errors package to support errors.Is() and errors.As(). 
+func (err *Diagnostics) Unwrap() error { + if u := errors.Unwrap(err.Wrapped); u != nil { + return u + } + return err.Wrapped +} diff --git a/base/database/migration/migration.go b/base/database/migration/migration.go new file mode 100644 index 000000000..73b400839 --- /dev/null +++ b/base/database/migration/migration.go @@ -0,0 +1,220 @@ +package migration + +import ( + "context" + "errors" + "fmt" + "sort" + "sync" + "time" + + "github.com/hashicorp/go-version" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" +) + +// MigrateFunc is called when a migration should be applied to the +// database. It receives the current version (from) and the target +// version (to) of the database and a dedicated interface for +// interacting with data stored in the DB. +// A dedicated log.ContextTracer is added to ctx for each migration +// run. +type MigrateFunc func(ctx context.Context, from, to *version.Version, dbInterface *database.Interface) error + +// Migration represents a registered data-migration that should be applied to +// some database. Migrations are stacked on top and executed in order of increasing +// version number (see Version field). +type Migration struct { + // Description provides a short human-readable description of the + // migration. + Description string + // Version should hold the version of the database/subsystem after + // the migration has been applied. + Version string + // MigrateFuc is executed when the migration should be performed. + MigrateFunc MigrateFunc +} + +// Registry holds a migration stack. +type Registry struct { + key string + + lock sync.Mutex + migrations []Migration +} + +// New creates a new migration registry. +// The key should be the name of the database key that is used to store +// the version of the last successfully applied migration. +func New(key string) *Registry { + return &Registry{ + key: key, + } +} + +// Add adds one or more migrations to reg. +func (reg *Registry) Add(migrations ...Migration) error { + reg.lock.Lock() + defer reg.lock.Unlock() + for _, m := range migrations { + if _, err := version.NewSemver(m.Version); err != nil { + return fmt.Errorf("migration %q: invalid version %s: %w", m.Description, m.Version, err) + } + reg.migrations = append(reg.migrations, m) + } + return nil +} + +// Migrate migrates the database by executing all registered +// migration in order of increasing version numbers. The error +// returned, if not nil, is always of type *Diagnostics. 
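Editor's note: putting the registry together means creating it with the database key that stores the applied version, adding migrations, and running Migrate (defined next); failures are reported as *Diagnostics. A sketch with an illustrative key and a no-op migration:

```go
package main

import (
	"context"
	"errors"
	"log"

	"github.com/hashicorp/go-version"

	"github.com/safing/portmaster/base/database"
	"github.com/safing/portmaster/base/database/migration"
)

// runMigrations wires up a migration registry for an illustrative subsystem.
// The key "core:migrations/mydata" and the single migration are placeholders.
func runMigrations(ctx context.Context) error {
	reg := migration.New("core:migrations/mydata")

	err := reg.Add(migration.Migration{
		Description: "initial schema normalization",
		Version:     "0.1.0",
		MigrateFunc: func(ctx context.Context, from, to *version.Version, db *database.Interface) error {
			// A real migration would read and rewrite records via db here.
			return nil
		},
	})
	if err != nil {
		return err
	}

	if err := reg.Migrate(ctx); err != nil {
		// Migrate reports failures as *Diagnostics with details about the
		// planned and failed steps.
		var diag *migration.Diagnostics
		if errors.As(err, &diag) {
			log.Printf("migration %q failed: %s", diag.FailedMigration, err)
		}
		return err
	}
	return nil
}
```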
+func (reg *Registry) Migrate(ctx context.Context) (err error) { + reg.lock.Lock() + defer reg.lock.Unlock() + + start := time.Now() + log.Infof("migration: migration of %s started", reg.key) + defer func() { + if err != nil { + log.Errorf("migration: migration of %s failed after %s: %s", reg.key, time.Since(start), err) + } else { + log.Infof("migration: migration of %s finished after %s", reg.key, time.Since(start)) + } + }() + + db := database.NewInterface(&database.Options{ + Local: true, + Internal: true, + }) + + startOfMigration, err := reg.getLatestSuccessfulMigration(db) + if err != nil { + return err + } + + execPlan, diag, err := reg.getExecutionPlan(startOfMigration) + if err != nil { + return err + } + if len(execPlan) == 0 { + return nil + } + diag.TargetVersion = execPlan[len(execPlan)-1].Version + + // finally, apply our migrations + lastAppliedMigration := startOfMigration + for _, m := range execPlan { + target, _ := version.NewSemver(m.Version) // we can safely ignore the error here + + migrationCtx, tracer := log.AddTracer(ctx) + + if err := m.MigrateFunc(migrationCtx, lastAppliedMigration, target, db); err != nil { + diag.Wrapped = err + diag.FailedMigration = m.Description + tracer.Errorf("migration: migration for %s failed: %s - %s", reg.key, target.String(), m.Description) + tracer.Submit() + return diag + } + + lastAppliedMigration = target + diag.LastSuccessfulMigration = lastAppliedMigration.String() + + if err := reg.saveLastSuccessfulMigration(db, target); err != nil { + diag.Message = "failed to persist migration status" + diag.Wrapped = err + diag.FailedMigration = m.Description + } + tracer.Infof("migration: applied migration for %s: %s - %s", reg.key, target.String(), m.Description) + tracer.Submit() + } + + // all migrations have been applied successfully, we're done here + return nil +} + +func (reg *Registry) getLatestSuccessfulMigration(db *database.Interface) (*version.Version, error) { + // find the latest version stored in the database + rec, err := db.Get(reg.key) + if errors.Is(err, database.ErrNotFound) { + return nil, nil + } + if err != nil { + return nil, &Diagnostics{ + Message: "failed to query database for migration status", + Wrapped: err, + } + } + + // Unwrap the record to get the actual database + r, ok := rec.(*record.Wrapper) + if !ok { + return nil, &Diagnostics{ + Wrapped: errors.New("expected wrapped database record"), + } + } + + sv, err := version.NewSemver(string(r.Data)) + if err != nil { + return nil, &Diagnostics{ + Message: "failed to parse version stored in migration status record", + Wrapped: err, + } + } + return sv, nil +} + +func (reg *Registry) saveLastSuccessfulMigration(db *database.Interface, ver *version.Version) error { + r := &record.Wrapper{ + Data: []byte(ver.String()), + Format: dsd.RAW, + } + r.SetKey(reg.key) + + return db.Put(r) +} + +func (reg *Registry) getExecutionPlan(startOfMigration *version.Version) ([]Migration, *Diagnostics, error) { + // create a look-up map for migrations indexed by their semver created a + // list of version (sorted by increasing number) that we use as our execution + // plan. 
+ lm := make(map[string]Migration) + versions := make(version.Collection, 0, len(reg.migrations)) + for _, m := range reg.migrations { + ver, err := version.NewSemver(m.Version) + if err != nil { + return nil, nil, &Diagnostics{ + Message: "failed to parse version of migration", + Wrapped: err, + FailedMigration: m.Description, + } + } + lm[ver.String()] = m // use .String() for a normalized string representation + versions = append(versions, ver) + } + sort.Sort(versions) + + diag := new(Diagnostics) + if startOfMigration != nil { + diag.StartOfMigration = startOfMigration.String() + } + + // prepare our diagnostics and the execution plan + execPlan := make([]Migration, 0, len(versions)) + for _, ver := range versions { + // skip an migration that has already been applied. + if startOfMigration != nil && startOfMigration.GreaterThanOrEqual(ver) { + continue + } + m := lm[ver.String()] + diag.ExecutionPlan = append(diag.ExecutionPlan, DiagnosticStep{ + Description: m.Description, + Version: ver.String(), + }) + execPlan = append(execPlan, m) + } + + return execPlan, diag, nil +} diff --git a/base/database/query/README.md b/base/database/query/README.md new file mode 100644 index 000000000..455eb9557 --- /dev/null +++ b/base/database/query/README.md @@ -0,0 +1,55 @@ +# Query + +## Control Flow + +- Grouping with `(` and `)` +- Chaining with `and` and `or` + - _NO_ mixing! Be explicit and use grouping. +- Negation with `not` + - in front of expression for group: `not (...)` + - inside expression for clause: `name not matches "^King "` + +## Selectors + +Supported by all feeders: +- root level field: `field` +- sub level field: `field.sub` +- array/slice/map access: `map.0` +- array/slice/map length: `map.#` + +Please note that some feeders may have other special characters. It is advised to only use alphanumeric characters for keys. + +## Operators + +| Name | Textual | Req. Type | Internal Type | Compared with | +|-------------------------|--------------------|-----------|---------------|---------------------------| +| Equals | `==` | int | int64 | `==` | +| GreaterThan | `>` | int | int64 | `>` | +| GreaterThanOrEqual | `>=` | int | int64 | `>=` | +| LessThan | `<` | int | int64 | `<` | +| LessThanOrEqual | `<=` | int | int64 | `<=` | +| FloatEquals | `f==` | float | float64 | `==` | +| FloatGreaterThan | `f>` | float | float64 | `>` | +| FloatGreaterThanOrEqual | `f>=` | float | float64 | `>=` | +| FloatLessThan | `f<` | float | float64 | `<` | +| FloatLessThanOrEqual | `f<=` | float | float64 | `<=` | +| SameAs | `sameas`, `s==` | string | string | `==` | +| Contains | `contains`, `co` | string | string | `strings.Contains()` | +| StartsWith | `startswith`, `sw` | string | string | `strings.HasPrefix()` | +| EndsWith | `endswith`, `ew` | string | string | `strings.HasSuffix()` | +| In | `in` | string | string | for loop with `==` | +| Matches | `matches`, `re` | string | string | `regexp.Regexp.Matches()` | +| Is | `is` | bool* | bool | `==` | +| Exists | `exists`, `ex` | any | n/a | n/a | + +\*accepts strings: 1, t, T, true, True, TRUE, 0, f, F, false, False, FALSE + +## Escaping + +If you need to use a control character within a value (ie. not for controlling), escape it with `\`. +It is recommended to wrap a word into parenthesis instead of escaping control characters, when possible. 
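Editor's note: the operators listed above map onto the query package's Go condition constructors. A sketch that builds a condition equivalent to `Name endswith "bert" and Score > 100`; the "royals:" prefix is illustrative:

```go
package main

import "github.com/safing/portmaster/base/database/query"

// buildQuery assembles a query using the Go constructors for the operators
// described above.
func buildQuery() (*query.Query, error) {
	q := query.New("royals:").Where(query.And(
		query.Where("Name", query.EndsWith, "bert"),
		query.Where("Score", query.GreaterThan, 100),
	))

	// Check validates the conditions and returns any construction error.
	return q.Check()
}
```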
+ +| Location | Characters to be escaped | +|---|---| +| Within parenthesis (`"`) | `"`, `\` | +| Everywhere else | `(`, `)`, `"`, `\`, `\t`, `\r`, `\n`, ` ` (space) | diff --git a/base/database/query/condition-and.go b/base/database/query/condition-and.go new file mode 100644 index 000000000..473d4081a --- /dev/null +++ b/base/database/query/condition-and.go @@ -0,0 +1,46 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/safing/portmaster/base/database/accessor" +) + +// And combines multiple conditions with a logical _AND_ operator. +func And(conditions ...Condition) Condition { + return &andCond{ + conditions: conditions, + } +} + +type andCond struct { + conditions []Condition +} + +func (c *andCond) complies(acc accessor.Accessor) bool { + for _, cond := range c.conditions { + if !cond.complies(acc) { + return false + } + } + return true +} + +func (c *andCond) check() (err error) { + for _, cond := range c.conditions { + err = cond.check() + if err != nil { + return err + } + } + return nil +} + +func (c *andCond) string() string { + all := make([]string, 0, len(c.conditions)) + for _, cond := range c.conditions { + all = append(all, cond.string()) + } + return fmt.Sprintf("(%s)", strings.Join(all, " and ")) +} diff --git a/base/database/query/condition-bool.go b/base/database/query/condition-bool.go new file mode 100644 index 000000000..79994ad69 --- /dev/null +++ b/base/database/query/condition-bool.go @@ -0,0 +1,69 @@ +package query + +import ( + "errors" + "fmt" + "strconv" + + "github.com/safing/portmaster/base/database/accessor" +) + +type boolCondition struct { + key string + operator uint8 + value bool +} + +func newBoolCondition(key string, operator uint8, value interface{}) *boolCondition { + var parsedValue bool + + switch v := value.(type) { + case bool: + parsedValue = v + case string: + var err error + parsedValue, err = strconv.ParseBool(v) + if err != nil { + return &boolCondition{ + key: fmt.Sprintf("could not parse \"%s\" to bool: %s", v, err), + operator: errorPresent, + } + } + default: + return &boolCondition{ + key: fmt.Sprintf("incompatible value %v for int64", value), + operator: errorPresent, + } + } + + return &boolCondition{ + key: key, + operator: operator, + value: parsedValue, + } +} + +func (c *boolCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetBool(c.key) + if !ok { + return false + } + + switch c.operator { + case Is: + return comp == c.value + default: + return false + } +} + +func (c *boolCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *boolCondition) string() string { + return fmt.Sprintf("%s %s %t", escapeString(c.key), getOpName(c.operator), c.value) +} diff --git a/base/database/query/condition-error.go b/base/database/query/condition-error.go new file mode 100644 index 000000000..802d9217d --- /dev/null +++ b/base/database/query/condition-error.go @@ -0,0 +1,27 @@ +package query + +import ( + "github.com/safing/portmaster/base/database/accessor" +) + +type errorCondition struct { + err error +} + +func newErrorCondition(err error) *errorCondition { + return &errorCondition{ + err: err, + } +} + +func (c *errorCondition) complies(acc accessor.Accessor) bool { + return false +} + +func (c *errorCondition) check() error { + return c.err +} + +func (c *errorCondition) string() string { + return "[ERROR]" +} diff --git a/base/database/query/condition-exists.go b/base/database/query/condition-exists.go new file mode 100644 index 
000000000..c5d62227f --- /dev/null +++ b/base/database/query/condition-exists.go @@ -0,0 +1,35 @@ +package query + +import ( + "errors" + "fmt" + + "github.com/safing/portmaster/base/database/accessor" +) + +type existsCondition struct { + key string + operator uint8 +} + +func newExistsCondition(key string, operator uint8) *existsCondition { + return &existsCondition{ + key: key, + operator: operator, + } +} + +func (c *existsCondition) complies(acc accessor.Accessor) bool { + return acc.Exists(c.key) +} + +func (c *existsCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *existsCondition) string() string { + return fmt.Sprintf("%s %s", escapeString(c.key), getOpName(c.operator)) +} diff --git a/base/database/query/condition-float.go b/base/database/query/condition-float.go new file mode 100644 index 000000000..8ec4418a3 --- /dev/null +++ b/base/database/query/condition-float.go @@ -0,0 +1,97 @@ +package query + +import ( + "errors" + "fmt" + "strconv" + + "github.com/safing/portmaster/base/database/accessor" +) + +type floatCondition struct { + key string + operator uint8 + value float64 +} + +func newFloatCondition(key string, operator uint8, value interface{}) *floatCondition { + var parsedValue float64 + + switch v := value.(type) { + case int: + parsedValue = float64(v) + case int8: + parsedValue = float64(v) + case int16: + parsedValue = float64(v) + case int32: + parsedValue = float64(v) + case int64: + parsedValue = float64(v) + case uint: + parsedValue = float64(v) + case uint8: + parsedValue = float64(v) + case uint16: + parsedValue = float64(v) + case uint32: + parsedValue = float64(v) + case float32: + parsedValue = float64(v) + case float64: + parsedValue = v + case string: + var err error + parsedValue, err = strconv.ParseFloat(v, 64) + if err != nil { + return &floatCondition{ + key: fmt.Sprintf("could not parse %s to float64: %s", v, err), + operator: errorPresent, + } + } + default: + return &floatCondition{ + key: fmt.Sprintf("incompatible value %v for float64", value), + operator: errorPresent, + } + } + + return &floatCondition{ + key: key, + operator: operator, + value: parsedValue, + } +} + +func (c *floatCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetFloat(c.key) + if !ok { + return false + } + + switch c.operator { + case FloatEquals: + return comp == c.value + case FloatGreaterThan: + return comp > c.value + case FloatGreaterThanOrEqual: + return comp >= c.value + case FloatLessThan: + return comp < c.value + case FloatLessThanOrEqual: + return comp <= c.value + default: + return false + } +} + +func (c *floatCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *floatCondition) string() string { + return fmt.Sprintf("%s %s %g", escapeString(c.key), getOpName(c.operator), c.value) +} diff --git a/base/database/query/condition-int.go b/base/database/query/condition-int.go new file mode 100644 index 000000000..cdd8ecccf --- /dev/null +++ b/base/database/query/condition-int.go @@ -0,0 +1,93 @@ +package query + +import ( + "errors" + "fmt" + "strconv" + + "github.com/safing/portmaster/base/database/accessor" +) + +type intCondition struct { + key string + operator uint8 + value int64 +} + +func newIntCondition(key string, operator uint8, value interface{}) *intCondition { + var parsedValue int64 + + switch v := value.(type) { + case int: + parsedValue = int64(v) + case int8: + parsedValue = int64(v) + case int16: 
+ parsedValue = int64(v) + case int32: + parsedValue = int64(v) + case int64: + parsedValue = v + case uint: + parsedValue = int64(v) + case uint8: + parsedValue = int64(v) + case uint16: + parsedValue = int64(v) + case uint32: + parsedValue = int64(v) + case string: + var err error + parsedValue, err = strconv.ParseInt(v, 10, 64) + if err != nil { + return &intCondition{ + key: fmt.Sprintf("could not parse %s to int64: %s (hint: use \"sameas\" to compare strings)", v, err), + operator: errorPresent, + } + } + default: + return &intCondition{ + key: fmt.Sprintf("incompatible value %v for int64", value), + operator: errorPresent, + } + } + + return &intCondition{ + key: key, + operator: operator, + value: parsedValue, + } +} + +func (c *intCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetInt(c.key) + if !ok { + return false + } + + switch c.operator { + case Equals: + return comp == c.value + case GreaterThan: + return comp > c.value + case GreaterThanOrEqual: + return comp >= c.value + case LessThan: + return comp < c.value + case LessThanOrEqual: + return comp <= c.value + default: + return false + } +} + +func (c *intCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *intCondition) string() string { + return fmt.Sprintf("%s %s %d", escapeString(c.key), getOpName(c.operator), c.value) +} diff --git a/base/database/query/condition-not.go b/base/database/query/condition-not.go new file mode 100644 index 000000000..ac17f35a2 --- /dev/null +++ b/base/database/query/condition-not.go @@ -0,0 +1,36 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/safing/portmaster/base/database/accessor" +) + +// Not negates the supplied condition. +func Not(c Condition) Condition { + return ¬Cond{ + notC: c, + } +} + +type notCond struct { + notC Condition +} + +func (c *notCond) complies(acc accessor.Accessor) bool { + return !c.notC.complies(acc) +} + +func (c *notCond) check() error { + return c.notC.check() +} + +func (c *notCond) string() string { + next := c.notC.string() + if strings.HasPrefix(next, "(") { + return fmt.Sprintf("not %s", c.notC.string()) + } + splitted := strings.Split(next, " ") + return strings.Join(append([]string{splitted[0], "not"}, splitted[1:]...), " ") +} diff --git a/base/database/query/condition-or.go b/base/database/query/condition-or.go new file mode 100644 index 000000000..8dffe6f26 --- /dev/null +++ b/base/database/query/condition-or.go @@ -0,0 +1,46 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/safing/portmaster/base/database/accessor" +) + +// Or combines multiple conditions with a logical _OR_ operator. 
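+//
+// A minimal illustrative sketch, using keys and values taken from this
+// package's tests:
+//
+//	Or(
+//		Where("coconuts", LessThan, 10),
+//		Where("name", SameAs, "Julian"),
+//	)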
+func Or(conditions ...Condition) Condition { + return &orCond{ + conditions: conditions, + } +} + +type orCond struct { + conditions []Condition +} + +func (c *orCond) complies(acc accessor.Accessor) bool { + for _, cond := range c.conditions { + if cond.complies(acc) { + return true + } + } + return false +} + +func (c *orCond) check() (err error) { + for _, cond := range c.conditions { + err = cond.check() + if err != nil { + return err + } + } + return nil +} + +func (c *orCond) string() string { + all := make([]string, 0, len(c.conditions)) + for _, cond := range c.conditions { + all = append(all, cond.string()) + } + return fmt.Sprintf("(%s)", strings.Join(all, " or ")) +} diff --git a/base/database/query/condition-regex.go b/base/database/query/condition-regex.go new file mode 100644 index 000000000..06937fb3e --- /dev/null +++ b/base/database/query/condition-regex.go @@ -0,0 +1,63 @@ +package query + +import ( + "errors" + "fmt" + "regexp" + + "github.com/safing/portmaster/base/database/accessor" +) + +type regexCondition struct { + key string + operator uint8 + regex *regexp.Regexp +} + +func newRegexCondition(key string, operator uint8, value interface{}) *regexCondition { + switch v := value.(type) { + case string: + r, err := regexp.Compile(v) + if err != nil { + return ®exCondition{ + key: fmt.Sprintf("could not compile regex \"%s\": %s", v, err), + operator: errorPresent, + } + } + return ®exCondition{ + key: key, + operator: operator, + regex: r, + } + default: + return ®exCondition{ + key: fmt.Sprintf("incompatible value %v for string", value), + operator: errorPresent, + } + } +} + +func (c *regexCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetString(c.key) + if !ok { + return false + } + + switch c.operator { + case Matches: + return c.regex.MatchString(comp) + default: + return false + } +} + +func (c *regexCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *regexCondition) string() string { + return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(c.regex.String())) +} diff --git a/base/database/query/condition-string.go b/base/database/query/condition-string.go new file mode 100644 index 000000000..8a308d2ba --- /dev/null +++ b/base/database/query/condition-string.go @@ -0,0 +1,62 @@ +package query + +import ( + "errors" + "fmt" + "strings" + + "github.com/safing/portmaster/base/database/accessor" +) + +type stringCondition struct { + key string + operator uint8 + value string +} + +func newStringCondition(key string, operator uint8, value interface{}) *stringCondition { + switch v := value.(type) { + case string: + return &stringCondition{ + key: key, + operator: operator, + value: v, + } + default: + return &stringCondition{ + key: fmt.Sprintf("incompatible value %v for string", value), + operator: errorPresent, + } + } +} + +func (c *stringCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetString(c.key) + if !ok { + return false + } + + switch c.operator { + case SameAs: + return c.value == comp + case Contains: + return strings.Contains(comp, c.value) + case StartsWith: + return strings.HasPrefix(comp, c.value) + case EndsWith: + return strings.HasSuffix(comp, c.value) + default: + return false + } +} + +func (c *stringCondition) check() error { + if c.operator == errorPresent { + return errors.New(c.key) + } + return nil +} + +func (c *stringCondition) string() string { + return fmt.Sprintf("%s %s %s", escapeString(c.key), 
getOpName(c.operator), escapeString(c.value)) +} diff --git a/base/database/query/condition-stringslice.go b/base/database/query/condition-stringslice.go new file mode 100644 index 000000000..ade2692fe --- /dev/null +++ b/base/database/query/condition-stringslice.go @@ -0,0 +1,69 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/safing/portmaster/base/database/accessor" + "github.com/safing/portmaster/base/utils" +) + +type stringSliceCondition struct { + key string + operator uint8 + value []string +} + +func newStringSliceCondition(key string, operator uint8, value interface{}) *stringSliceCondition { + switch v := value.(type) { + case string: + parsedValue := strings.Split(v, ",") + if len(parsedValue) < 2 { + return &stringSliceCondition{ + key: v, + operator: errorPresent, + } + } + return &stringSliceCondition{ + key: key, + operator: operator, + value: parsedValue, + } + case []string: + return &stringSliceCondition{ + key: key, + operator: operator, + value: v, + } + default: + return &stringSliceCondition{ + key: fmt.Sprintf("incompatible value %v for []string", value), + operator: errorPresent, + } + } +} + +func (c *stringSliceCondition) complies(acc accessor.Accessor) bool { + comp, ok := acc.GetString(c.key) + if !ok { + return false + } + + switch c.operator { + case In: + return utils.StringInSlice(c.value, comp) + default: + return false + } +} + +func (c *stringSliceCondition) check() error { + if c.operator == errorPresent { + return fmt.Errorf("could not parse \"%s\" to []string", c.key) + } + return nil +} + +func (c *stringSliceCondition) string() string { + return fmt.Sprintf("%s %s %s", escapeString(c.key), getOpName(c.operator), escapeString(strings.Join(c.value, ","))) +} diff --git a/base/database/query/condition.go b/base/database/query/condition.go new file mode 100644 index 000000000..dbedf17a9 --- /dev/null +++ b/base/database/query/condition.go @@ -0,0 +1,71 @@ +package query + +import ( + "fmt" + + "github.com/safing/portmaster/base/database/accessor" +) + +// Condition is an interface to provide a common api to all condition types. +type Condition interface { + complies(acc accessor.Accessor) bool + check() error + string() string +} + +// Operators. +const ( + Equals uint8 = iota // int + GreaterThan // int + GreaterThanOrEqual // int + LessThan // int + LessThanOrEqual // int + FloatEquals // float + FloatGreaterThan // float + FloatGreaterThanOrEqual // float + FloatLessThan // float + FloatLessThanOrEqual // float + SameAs // string + Contains // string + StartsWith // string + EndsWith // string + In // stringSlice + Matches // regex + Is // bool: accepts 1, t, T, TRUE, true, True, 0, f, F, FALSE + Exists // any + + errorPresent uint8 = 255 +) + +// Where returns a condition to add to a query. 
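+//
+// For example (values taken from this package's tests and README):
+//
+//	Where("bananas", GreaterThan, 100)
+//	Where("name", Matches, "^King ")
+//	Where("banana", Exists, nil)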
+func Where(key string, operator uint8, value interface{}) Condition { + switch operator { + case Equals, + GreaterThan, + GreaterThanOrEqual, + LessThan, + LessThanOrEqual: + return newIntCondition(key, operator, value) + case FloatEquals, + FloatGreaterThan, + FloatGreaterThanOrEqual, + FloatLessThan, + FloatLessThanOrEqual: + return newFloatCondition(key, operator, value) + case SameAs, + Contains, + StartsWith, + EndsWith: + return newStringCondition(key, operator, value) + case In: + return newStringSliceCondition(key, operator, value) + case Matches: + return newRegexCondition(key, operator, value) + case Is: + return newBoolCondition(key, operator, value) + case Exists: + return newExistsCondition(key, operator) + default: + return newErrorCondition(fmt.Errorf("no operator with ID %d", operator)) + } +} diff --git a/base/database/query/condition_test.go b/base/database/query/condition_test.go new file mode 100644 index 000000000..0ce2e55ba --- /dev/null +++ b/base/database/query/condition_test.go @@ -0,0 +1,86 @@ +package query + +import "testing" + +func testSuccess(t *testing.T, c Condition) { + t.Helper() + + err := c.check() + if err != nil { + t.Errorf("failed: %s", err) + } +} + +func TestInterfaces(t *testing.T) { + t.Parallel() + + testSuccess(t, newIntCondition("banana", Equals, uint(1))) + testSuccess(t, newIntCondition("banana", Equals, uint8(1))) + testSuccess(t, newIntCondition("banana", Equals, uint16(1))) + testSuccess(t, newIntCondition("banana", Equals, uint32(1))) + testSuccess(t, newIntCondition("banana", Equals, int(1))) + testSuccess(t, newIntCondition("banana", Equals, int8(1))) + testSuccess(t, newIntCondition("banana", Equals, int16(1))) + testSuccess(t, newIntCondition("banana", Equals, int32(1))) + testSuccess(t, newIntCondition("banana", Equals, int64(1))) + testSuccess(t, newIntCondition("banana", Equals, "1")) + + testSuccess(t, newFloatCondition("banana", FloatEquals, uint(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, uint8(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, uint16(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, uint32(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int8(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int16(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int32(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, int64(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, float32(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, float64(1))) + testSuccess(t, newFloatCondition("banana", FloatEquals, "1.1")) + + testSuccess(t, newStringCondition("banana", SameAs, "coconut")) + testSuccess(t, newRegexCondition("banana", Matches, "coconut")) + testSuccess(t, newStringSliceCondition("banana", FloatEquals, []string{"banana", "coconut"})) + testSuccess(t, newStringSliceCondition("banana", FloatEquals, "banana,coconut")) +} + +func testCondError(t *testing.T, c Condition) { + t.Helper() + + err := c.check() + if err == nil { + t.Error("should fail") + } +} + +func TestConditionErrors(t *testing.T) { + t.Parallel() + + // test invalid value types + testCondError(t, newBoolCondition("banana", Is, 1)) + testCondError(t, newFloatCondition("banana", FloatEquals, true)) + testCondError(t, newIntCondition("banana", Equals, true)) + testCondError(t, newStringCondition("banana", SameAs, 1)) + testCondError(t, newRegexCondition("banana", Matches, 1)) + 
testCondError(t, newStringSliceCondition("banana", Matches, 1)) + + // test error presence + testCondError(t, newBoolCondition("banana", errorPresent, true)) + testCondError(t, And(newBoolCondition("banana", errorPresent, true))) + testCondError(t, Or(newBoolCondition("banana", errorPresent, true))) + testCondError(t, newExistsCondition("banana", errorPresent)) + testCondError(t, newFloatCondition("banana", errorPresent, 1.1)) + testCondError(t, newIntCondition("banana", errorPresent, 1)) + testCondError(t, newStringCondition("banana", errorPresent, "coconut")) + testCondError(t, newRegexCondition("banana", errorPresent, "coconut")) +} + +func TestWhere(t *testing.T) { + t.Parallel() + + c := Where("", 254, nil) + err := c.check() + if err == nil { + t.Error("should fail") + } +} diff --git a/base/database/query/operators.go b/base/database/query/operators.go new file mode 100644 index 000000000..bbd21ee44 --- /dev/null +++ b/base/database/query/operators.go @@ -0,0 +1,53 @@ +package query + +var ( + operatorNames = map[string]uint8{ + "==": Equals, + ">": GreaterThan, + ">=": GreaterThanOrEqual, + "<": LessThan, + "<=": LessThanOrEqual, + "f==": FloatEquals, + "f>": FloatGreaterThan, + "f>=": FloatGreaterThanOrEqual, + "f<": FloatLessThan, + "f<=": FloatLessThanOrEqual, + "sameas": SameAs, + "s==": SameAs, + "contains": Contains, + "co": Contains, + "startswith": StartsWith, + "sw": StartsWith, + "endswith": EndsWith, + "ew": EndsWith, + "in": In, + "matches": Matches, + "re": Matches, + "is": Is, + "exists": Exists, + "ex": Exists, + } + + primaryNames = make(map[uint8]string) +) + +func init() { + for opName, opID := range operatorNames { + name, ok := primaryNames[opID] + if ok { + if len(name) < len(opName) { + primaryNames[opID] = opName + } + } else { + primaryNames[opID] = opName + } + } +} + +func getOpName(operator uint8) string { + name, ok := primaryNames[operator] + if ok { + return name + } + return "[unknown]" +} diff --git a/base/database/query/operators_test.go b/base/database/query/operators_test.go new file mode 100644 index 000000000..9fa844107 --- /dev/null +++ b/base/database/query/operators_test.go @@ -0,0 +1,11 @@ +package query + +import "testing" + +func TestGetOpName(t *testing.T) { + t.Parallel() + + if getOpName(254) != "[unknown]" { + t.Error("unexpected output") + } +} diff --git a/base/database/query/parser.go b/base/database/query/parser.go new file mode 100644 index 000000000..b6abd390c --- /dev/null +++ b/base/database/query/parser.go @@ -0,0 +1,350 @@ +package query + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +type snippet struct { + text string + globalPosition int +} + +// ParseQuery parses a plaintext query. Special characters (that must be escaped with a '\') are: `\()` and any whitespaces. +// +//nolint:gocognit +func ParseQuery(query string) (*Query, error) { + snippets, err := extractSnippets(query) + if err != nil { + return nil, err + } + snippetsPos := 0 + + getSnippet := func() (*snippet, error) { + // order is important, as parseAndOr will always consume one additional snippet. 
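+ // ParseQuery compensates for this look-ahead by stepping snippetsPos back
+ // by one after parseAndOr returns (see the "go one back" note below).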
+ snippetsPos++ + if snippetsPos > len(snippets) { + return nil, fmt.Errorf("unexpected end at position %d", len(query)) + } + return snippets[snippetsPos-1], nil + } + remainingSnippets := func() int { + return len(snippets) - snippetsPos + } + + // check for query word + queryWord, err := getSnippet() + if err != nil { + return nil, err + } + if queryWord.text != "query" { + return nil, errors.New("queries must start with \"query\"") + } + + // get prefix + prefix, err := getSnippet() + if err != nil { + return nil, err + } + q := New(prefix.text) + + for remainingSnippets() > 0 { + command, err := getSnippet() + if err != nil { + return nil, err + } + + switch command.text { + case "where": + if q.where != nil { + return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition) + } + + // parse conditions + condition, err := parseAndOr(getSnippet, remainingSnippets, true) + if err != nil { + return nil, err + } + // go one back, as parseAndOr had to check if its done + snippetsPos-- + + q.Where(condition) + case "orderby": + if q.orderBy != "" { + return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition) + } + + orderBySnippet, err := getSnippet() + if err != nil { + return nil, err + } + + q.OrderBy(orderBySnippet.text) + case "limit": + if q.limit != 0 { + return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition) + } + + limitSnippet, err := getSnippet() + if err != nil { + return nil, err + } + limit, err := strconv.ParseUint(limitSnippet.text, 10, 31) + if err != nil { + return nil, fmt.Errorf("could not parse integer (%s) at position %d", limitSnippet.text, limitSnippet.globalPosition) + } + + q.Limit(int(limit)) + case "offset": + if q.offset != 0 { + return nil, fmt.Errorf("duplicate \"%s\" clause found at position %d", command.text, command.globalPosition) + } + + offsetSnippet, err := getSnippet() + if err != nil { + return nil, err + } + offset, err := strconv.ParseUint(offsetSnippet.text, 10, 31) + if err != nil { + return nil, fmt.Errorf("could not parse integer (%s) at position %d", offsetSnippet.text, offsetSnippet.globalPosition) + } + + q.Offset(int(offset)) + default: + return nil, fmt.Errorf("unknown clause \"%s\" at position %d", command.text, command.globalPosition) + } + } + + return q.Check() +} + +func extractSnippets(text string) (snippets []*snippet, err error) { + skip := false + start := -1 + inParenthesis := false + var pos int + var char rune + + for pos, char = range text { + + // skip + if skip { + skip = false + continue + } + if char == '\\' { + skip = true + } + + // wait for parenthesis to be overs + if inParenthesis { + if char == '"' { + snippets = append(snippets, &snippet{ + text: prepToken(text[start+1 : pos]), + globalPosition: start + 1, + }) + start = -1 + inParenthesis = false + } + continue + } + + // handle segments + switch char { + case '\t', '\n', '\r', ' ', '(', ')': + if start >= 0 { + snippets = append(snippets, &snippet{ + text: prepToken(text[start:pos]), + globalPosition: start + 1, + }) + start = -1 + } + default: + if start == -1 { + start = pos + } + } + + // handle special segment characters + switch char { + case '(', ')': + snippets = append(snippets, &snippet{ + text: text[pos : pos+1], + globalPosition: pos + 1, + }) + case '"': + if start < pos { + return nil, fmt.Errorf("parenthesis ('\"') may not be used within words, please escape with '\\' (position: %d)", pos+1) + } + 
inParenthesis = true + } + + } + + // add last + if start >= 0 { + snippets = append(snippets, &snippet{ + text: prepToken(text[start : pos+1]), + globalPosition: start + 1, + }) + } + + return snippets, nil +} + +//nolint:gocognit +func parseAndOr(getSnippet func() (*snippet, error), remainingSnippets func() int, rootCondition bool) (Condition, error) { + var ( + isOr = false + typeSet = false + wrapInNot = false + expectingMore = true + conditions []Condition + ) + + for { + if !expectingMore && rootCondition && remainingSnippets() == 0 { + // advance snippetsPos by one, as it will be set back by 1 + _, _ = getSnippet() + if len(conditions) == 1 { + return conditions[0], nil + } + if isOr { + return Or(conditions...), nil + } + return And(conditions...), nil + } + + firstSnippet, err := getSnippet() + if err != nil { + return nil, err + } + + if !expectingMore && rootCondition { + switch firstSnippet.text { + case "orderby", "limit", "offset": + if len(conditions) == 1 { + return conditions[0], nil + } + if isOr { + return Or(conditions...), nil + } + return And(conditions...), nil + } + } + + switch firstSnippet.text { + case "(": + condition, err := parseAndOr(getSnippet, remainingSnippets, false) + if err != nil { + return nil, err + } + if wrapInNot { + conditions = append(conditions, Not(condition)) + wrapInNot = false + } else { + conditions = append(conditions, condition) + } + expectingMore = true + case ")": + if len(conditions) == 1 { + return conditions[0], nil + } + if isOr { + return Or(conditions...), nil + } + return And(conditions...), nil + case "and": + if typeSet && isOr { + return nil, fmt.Errorf("you may not mix \"and\" and \"or\" (position: %d)", firstSnippet.globalPosition) + } + isOr = false + typeSet = true + expectingMore = true + case "or": + if typeSet && !isOr { + return nil, fmt.Errorf("you may not mix \"and\" and \"or\" (position: %d)", firstSnippet.globalPosition) + } + isOr = true + typeSet = true + expectingMore = true + case "not": + wrapInNot = true + expectingMore = true + default: + condition, err := parseCondition(firstSnippet, getSnippet) + if err != nil { + return nil, err + } + if wrapInNot { + conditions = append(conditions, Not(condition)) + wrapInNot = false + } else { + conditions = append(conditions, condition) + } + expectingMore = false + } + } +} + +func parseCondition(firstSnippet *snippet, getSnippet func() (*snippet, error)) (Condition, error) { + wrapInNot := false + + // get operator name + opName, err := getSnippet() + if err != nil { + return nil, err + } + // negate? + if opName.text == "not" { + wrapInNot = true + opName, err = getSnippet() + if err != nil { + return nil, err + } + } + + // get operator + operator, ok := operatorNames[opName.text] + if !ok { + return nil, fmt.Errorf("unknown operator at position %d", opName.globalPosition) + } + + // don't need a value for "exists" + if operator == Exists { + if wrapInNot { + return Not(Where(firstSnippet.text, operator, nil)), nil + } + return Where(firstSnippet.text, operator, nil), nil + } + + // get value + value, err := getSnippet() + if err != nil { + return nil, err + } + if wrapInNot { + return Not(Where(firstSnippet.text, operator, value.text)), nil + } + return Where(firstSnippet.text, operator, value.text), nil +} + +var escapeReplacer = regexp.MustCompile(`\\([^\\])`) + +// prepToken removes surrounding parenthesis and escape characters. 
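+// For example, both `"^King "` and `^King\ ` should be reduced to `^King `
+// (with a trailing space).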
+func prepToken(text string) string { + return escapeReplacer.ReplaceAllString(strings.Trim(text, "\""), "$1") +} + +// escapeString correctly escapes a snippet for printing. +func escapeString(token string) string { + // check if token contains characters that need to be escaped + if strings.ContainsAny(token, "()\"\\\t\r\n ") { + // put the token in parenthesis and only escape \ and " + return fmt.Sprintf("\"%s\"", strings.ReplaceAll(token, "\"", "\\\"")) + } + return token +} diff --git a/base/database/query/parser_test.go b/base/database/query/parser_test.go new file mode 100644 index 000000000..fb30ad824 --- /dev/null +++ b/base/database/query/parser_test.go @@ -0,0 +1,177 @@ +package query + +import ( + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestExtractSnippets(t *testing.T) { + t.Parallel() + + text1 := `query test: where ( "bananas" > 100 and monkeys.# <= "12")or(coconuts < 10 "and" area > 50) or name sameas Julian or name matches ^King\ ` + result1 := []*snippet{ + {text: "query", globalPosition: 1}, + {text: "test:", globalPosition: 7}, + {text: "where", globalPosition: 13}, + {text: "(", globalPosition: 19}, + {text: "bananas", globalPosition: 21}, + {text: ">", globalPosition: 31}, + {text: "100", globalPosition: 33}, + {text: "and", globalPosition: 37}, + {text: "monkeys.#", globalPosition: 41}, + {text: "<=", globalPosition: 51}, + {text: "12", globalPosition: 54}, + {text: ")", globalPosition: 58}, + {text: "or", globalPosition: 59}, + {text: "(", globalPosition: 61}, + {text: "coconuts", globalPosition: 62}, + {text: "<", globalPosition: 71}, + {text: "10", globalPosition: 73}, + {text: "and", globalPosition: 76}, + {text: "area", globalPosition: 82}, + {text: ">", globalPosition: 87}, + {text: "50", globalPosition: 89}, + {text: ")", globalPosition: 91}, + {text: "or", globalPosition: 93}, + {text: "name", globalPosition: 96}, + {text: "sameas", globalPosition: 101}, + {text: "Julian", globalPosition: 108}, + {text: "or", globalPosition: 115}, + {text: "name", globalPosition: 118}, + {text: "matches", globalPosition: 123}, + {text: "^King ", globalPosition: 131}, + } + + snippets, err := extractSnippets(text1) + if err != nil { + t.Errorf("failed to extract snippets: %s", err) + } + + if !reflect.DeepEqual(result1, snippets) { + t.Errorf("unexpected results:") + for _, el := range snippets { + t.Errorf("%+v", el) + } + } + + // t.Error(spew.Sprintf("%v", treeElement)) +} + +func testParsing(t *testing.T, queryText string, expectedResult *Query) { + t.Helper() + + _, err := expectedResult.Check() + if err != nil { + t.Errorf("failed to create query: %s", err) + return + } + + q, err := ParseQuery(queryText) + if err != nil { + t.Errorf("failed to parse query: %s", err) + return + } + + if queryText != q.Print() { + t.Errorf("string match failed: %s", q.Print()) + return + } + if !reflect.DeepEqual(expectedResult, q) { + t.Error("deepqual match failed.") + t.Error("got:") + t.Error(spew.Sdump(q)) + t.Error("expected:") + t.Error(spew.Sdump(expectedResult)) + } +} + +func TestParseQuery(t *testing.T) { + t.Parallel() + + text1 := `query test: where (bananas > 100 and monkeys.# <= 12) or not (coconuts < 10 and area not > 50) or name sameas Julian or name matches "^King " orderby name limit 10 offset 20` + result1 := New("test:").Where(Or( + And( + Where("bananas", GreaterThan, 100), + Where("monkeys.#", LessThanOrEqual, 12), + ), + Not(And( + Where("coconuts", LessThan, 10), + Not(Where("area", GreaterThan, 50)), + )), + Where("name", SameAs, 
"Julian"), + Where("name", Matches, "^King "), + )).OrderBy("name").Limit(10).Offset(20) + testParsing(t, text1, result1) + + testParsing(t, `query test: orderby name`, New("test:").OrderBy("name")) + testParsing(t, `query test: limit 10`, New("test:").Limit(10)) + testParsing(t, `query test: offset 10`, New("test:").Offset(10)) + testParsing(t, `query test: where banana matches ^ban`, New("test:").Where(Where("banana", Matches, "^ban"))) + testParsing(t, `query test: where banana exists`, New("test:").Where(Where("banana", Exists, nil))) + testParsing(t, `query test: where banana not exists`, New("test:").Where(Not(Where("banana", Exists, nil)))) + + // test all operators + testParsing(t, `query test: where banana == 1`, New("test:").Where(Where("banana", Equals, 1))) + testParsing(t, `query test: where banana > 1`, New("test:").Where(Where("banana", GreaterThan, 1))) + testParsing(t, `query test: where banana >= 1`, New("test:").Where(Where("banana", GreaterThanOrEqual, 1))) + testParsing(t, `query test: where banana < 1`, New("test:").Where(Where("banana", LessThan, 1))) + testParsing(t, `query test: where banana <= 1`, New("test:").Where(Where("banana", LessThanOrEqual, 1))) + testParsing(t, `query test: where banana f== 1.1`, New("test:").Where(Where("banana", FloatEquals, 1.1))) + testParsing(t, `query test: where banana f> 1.1`, New("test:").Where(Where("banana", FloatGreaterThan, 1.1))) + testParsing(t, `query test: where banana f>= 1.1`, New("test:").Where(Where("banana", FloatGreaterThanOrEqual, 1.1))) + testParsing(t, `query test: where banana f< 1.1`, New("test:").Where(Where("banana", FloatLessThan, 1.1))) + testParsing(t, `query test: where banana f<= 1.1`, New("test:").Where(Where("banana", FloatLessThanOrEqual, 1.1))) + testParsing(t, `query test: where banana sameas banana`, New("test:").Where(Where("banana", SameAs, "banana"))) + testParsing(t, `query test: where banana contains banana`, New("test:").Where(Where("banana", Contains, "banana"))) + testParsing(t, `query test: where banana startswith banana`, New("test:").Where(Where("banana", StartsWith, "banana"))) + testParsing(t, `query test: where banana endswith banana`, New("test:").Where(Where("banana", EndsWith, "banana"))) + testParsing(t, `query test: where banana in banana,coconut`, New("test:").Where(Where("banana", In, []string{"banana", "coconut"}))) + testParsing(t, `query test: where banana matches banana`, New("test:").Where(Where("banana", Matches, "banana"))) + testParsing(t, `query test: where banana is true`, New("test:").Where(Where("banana", Is, true))) + testParsing(t, `query test: where banana exists`, New("test:").Where(Where("banana", Exists, nil))) + + // special + testParsing(t, `query test: where banana not exists`, New("test:").Where(Not(Where("banana", Exists, nil)))) +} + +func testParseError(t *testing.T, queryText string, expectedErrorString string) { + t.Helper() + + _, err := ParseQuery(queryText) + if err == nil { + t.Errorf("should fail to parse: %s", queryText) + return + } + if err.Error() != expectedErrorString { + t.Errorf("unexpected error for query: %s\nwanted: %s\n got: %s", queryText, expectedErrorString, err) + } +} + +func TestParseErrors(t *testing.T) { + t.Parallel() + + // syntax + testParseError(t, `query`, `unexpected end at position 5`) + testParseError(t, `query test: where`, `unexpected end at position 17`) + testParseError(t, `query test: where (`, `unexpected end at position 19`) + testParseError(t, `query test: where )`, `unknown clause ")" at position 19`) + 
testParseError(t, `query test: where not`, `unexpected end at position 21`) + testParseError(t, `query test: where banana`, `unexpected end at position 24`) + testParseError(t, `query test: where banana >`, `unexpected end at position 26`) + testParseError(t, `query test: where banana nope`, `unknown operator at position 26`) + testParseError(t, `query test: where banana exists or`, `unexpected end at position 34`) + testParseError(t, `query test: where banana exists and`, `unexpected end at position 35`) + testParseError(t, `query test: where banana exists and (`, `unexpected end at position 37`) + testParseError(t, `query test: where banana exists and banana is true or`, `you may not mix "and" and "or" (position: 52)`) + testParseError(t, `query test: where banana exists or banana is true and`, `you may not mix "and" and "or" (position: 51)`) + // testParseError(t, `query test: where banana exists and (`, ``) + + // value parsing error + testParseError(t, `query test: where banana == banana`, `could not parse banana to int64: strconv.ParseInt: parsing "banana": invalid syntax (hint: use "sameas" to compare strings)`) + testParseError(t, `query test: where banana f== banana`, `could not parse banana to float64: strconv.ParseFloat: parsing "banana": invalid syntax`) + testParseError(t, `query test: where banana in banana`, `could not parse "banana" to []string`) + testParseError(t, `query test: where banana matches [banana`, "could not compile regex \"[banana\": error parsing regexp: missing closing ]: `[banana`") + testParseError(t, `query test: where banana is great`, `could not parse "great" to bool: strconv.ParseBool: parsing "great": invalid syntax`) +} diff --git a/base/database/query/query.go b/base/database/query/query.go new file mode 100644 index 000000000..d1be1408c --- /dev/null +++ b/base/database/query/query.go @@ -0,0 +1,170 @@ +package query + +import ( + "fmt" + "strings" + + "github.com/safing/portmaster/base/database/accessor" + "github.com/safing/portmaster/base/database/record" +) + +// Example: +// q.New("core:/", +// q.Where("a", q.GreaterThan, 0), +// q.Where("b", q.Equals, 0), +// q.Or( +// q.Where("c", q.StartsWith, "x"), +// q.Where("d", q.Contains, "y") +// ) +// ) + +// Query contains a compiled query. +type Query struct { + checked bool + dbName string + dbKeyPrefix string + where Condition + orderBy string + limit int + offset int +} + +// New creates a new query with the supplied prefix. +func New(prefix string) *Query { + dbName, dbKeyPrefix := record.ParseKey(prefix) + return &Query{ + dbName: dbName, + dbKeyPrefix: dbKeyPrefix, + } +} + +// Where adds filtering. +func (q *Query) Where(condition Condition) *Query { + q.where = condition + return q +} + +// Limit limits the number of returned results. +func (q *Query) Limit(limit int) *Query { + q.limit = limit + return q +} + +// Offset sets the query offset. +func (q *Query) Offset(offset int) *Query { + q.offset = offset + return q +} + +// OrderBy orders the results by the given key. +func (q *Query) OrderBy(key string) *Query { + q.orderBy = key + return q +} + +// Check checks for errors in the query. +func (q *Query) Check() (*Query, error) { + if q.checked { + return q, nil + } + + // check condition + if q.where != nil { + err := q.where.check() + if err != nil { + return nil, err + } + } + + q.checked = true + return q, nil +} + +// MustBeValid checks for errors in the query and panics if there is an error. 
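+//
+// It is typically chained directly onto a freshly built query, as in the
+// package tests:
+//
+//	q := New("test:").Where(Where("banana", Exists, nil)).MustBeValid()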
+func (q *Query) MustBeValid() *Query { + _, err := q.Check() + if err != nil { + panic(err) + } + return q +} + +// IsChecked returns whether they query was checked. +func (q *Query) IsChecked() bool { + return q.checked +} + +// MatchesKey checks whether the query matches the supplied database key (key without database prefix). +func (q *Query) MatchesKey(dbKey string) bool { + return strings.HasPrefix(dbKey, q.dbKeyPrefix) +} + +// MatchesRecord checks whether the query matches the supplied database record (value only). +func (q *Query) MatchesRecord(r record.Record) bool { + if q.where == nil { + return true + } + + acc := r.GetAccessor(r) + if acc == nil { + return false + } + return q.where.complies(acc) +} + +// MatchesAccessor checks whether the query matches the supplied accessor (value only). +func (q *Query) MatchesAccessor(acc accessor.Accessor) bool { + if q.where == nil { + return true + } + return q.where.complies(acc) +} + +// Matches checks whether the query matches the supplied database record. +func (q *Query) Matches(r record.Record) bool { + if !q.MatchesKey(r.DatabaseKey()) { + return false + } + return q.MatchesRecord(r) +} + +// Print returns the string representation of the query. +func (q *Query) Print() string { + var where string + if q.where != nil { + where = q.where.string() + if where != "" { + if strings.HasPrefix(where, "(") { + where = where[1 : len(where)-1] + } + where = fmt.Sprintf(" where %s", where) + } + } + + var orderBy string + if q.orderBy != "" { + orderBy = fmt.Sprintf(" orderby %s", q.orderBy) + } + + var limit string + if q.limit > 0 { + limit = fmt.Sprintf(" limit %d", q.limit) + } + + var offset string + if q.offset > 0 { + offset = fmt.Sprintf(" offset %d", q.offset) + } + + return fmt.Sprintf("query %s:%s%s%s%s%s", q.dbName, q.dbKeyPrefix, where, orderBy, limit, offset) +} + +// DatabaseName returns the name of the database. +func (q *Query) DatabaseName() string { + return q.dbName +} + +// DatabaseKeyPrefix returns the key prefix for the database. 
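+// For example, for a query created with the illustrative prefix "core:profiles/",
+// this would return "profiles/".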
+func (q *Query) DatabaseKeyPrefix() string { + return q.dbKeyPrefix +} diff --git a/base/database/query/query_test.go b/base/database/query/query_test.go new file mode 100644 index 000000000..402dc6151 --- /dev/null +++ b/base/database/query/query_test.go @@ -0,0 +1,113 @@ +//nolint:unparam +package query + +import ( + "testing" + + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" +) + +// copied from https://github.com/tidwall/gjson/blob/master/gjson_test.go +var testJSON = `{"age":100, "name":{"here":"B\\\"R"}, + "noop":{"what is a wren?":"a bird"}, + "happy":true,"immortal":false, + "items":[1,2,3,{"tags":[1,2,3],"points":[[1,2],[3,4]]},4,5,6,7], + "arr":["1",2,"3",{"hello":"world"},"4",5], + "vals":[1,2,3,{"sadf":sdf"asdf"}],"name":{"first":"tom","last":null}, + "created":"2014-05-16T08:28:06.989Z", + "loggy":{ + "programmers": [ + { + "firstName": "Brett", + "lastName": "McLaughlin", + "email": "aaaa", + "tag": "good" + }, + { + "firstName": "Jason", + "lastName": "Hunter", + "email": "bbbb", + "tag": "bad" + }, + { + "firstName": "Elliotte", + "lastName": "Harold", + "email": "cccc", + "tag":, "good" + }, + { + "firstName": 1002.3, + "age": 101 + } + ] + }, + "lastly":{"yay":"final"}, + "temperature": 120.413 +}` + +func testQuery(t *testing.T, r record.Record, shouldMatch bool, condition Condition) { + t.Helper() + + q := New("test:").Where(condition).MustBeValid() + // fmt.Printf("%s\n", q.Print()) + + matched := q.Matches(r) + switch { + case !matched && shouldMatch: + t.Errorf("should match: %s", q.Print()) + case matched && !shouldMatch: + t.Errorf("should not match: %s", q.Print()) + } +} + +func TestQuery(t *testing.T) { + t.Parallel() + + // if !gjson.Valid(testJSON) { + // t.Fatal("test json is invalid") + // } + r, err := record.NewWrapper("", nil, dsd.JSON, []byte(testJSON)) + if err != nil { + t.Fatal(err) + } + + testQuery(t, r, true, Where("age", Equals, 100)) + testQuery(t, r, true, Where("age", GreaterThan, uint8(99))) + testQuery(t, r, true, Where("age", GreaterThanOrEqual, 99)) + testQuery(t, r, true, Where("age", GreaterThanOrEqual, 100)) + testQuery(t, r, true, Where("age", LessThan, 101)) + testQuery(t, r, true, Where("age", LessThanOrEqual, "101")) + testQuery(t, r, true, Where("age", LessThanOrEqual, 100)) + + testQuery(t, r, true, Where("temperature", FloatEquals, 120.413)) + testQuery(t, r, true, Where("temperature", FloatGreaterThan, 120)) + testQuery(t, r, true, Where("temperature", FloatGreaterThanOrEqual, 120)) + testQuery(t, r, true, Where("temperature", FloatGreaterThanOrEqual, 120.413)) + testQuery(t, r, true, Where("temperature", FloatLessThan, 121)) + testQuery(t, r, true, Where("temperature", FloatLessThanOrEqual, "121")) + testQuery(t, r, true, Where("temperature", FloatLessThanOrEqual, "120.413")) + + testQuery(t, r, true, Where("lastly.yay", SameAs, "final")) + testQuery(t, r, true, Where("lastly.yay", Contains, "ina")) + testQuery(t, r, true, Where("lastly.yay", StartsWith, "fin")) + testQuery(t, r, true, Where("lastly.yay", EndsWith, "nal")) + testQuery(t, r, true, Where("lastly.yay", In, "draft,final")) + testQuery(t, r, true, Where("lastly.yay", In, "final,draft")) + + testQuery(t, r, true, Where("happy", Is, true)) + testQuery(t, r, true, Where("happy", Is, "true")) + testQuery(t, r, true, Where("happy", Is, "t")) + testQuery(t, r, true, Not(Where("happy", Is, "0"))) + testQuery(t, r, true, And( + Where("happy", Is, "1"), + Not(Or( + Where("happy", Is, false), + Where("happy", Is, "f"), + 
)), + )) + + testQuery(t, r, true, Where("happy", Exists, nil)) + + testQuery(t, r, true, Where("created", Matches, "^2014-[0-9]{2}-[0-9]{2}T")) +} diff --git a/base/database/record/base.go b/base/database/record/base.go new file mode 100644 index 000000000..deacd78bf --- /dev/null +++ b/base/database/record/base.go @@ -0,0 +1,156 @@ +package record + +import ( + "errors" + + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/database/accessor" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" +) + +// TODO(ppacher): +// we can reduce the record.Record interface a lot by moving +// most of those functions that require the Record as it's first +// parameter to static package functions +// (i.e. Marshal, MarshalRecord, GetAccessor, ...). +// We should also consider given Base a GetBase() *Base method +// that returns itself. This way we can remove almost all Base +// only methods from the record.Record interface. That is, we can +// remove all those CreateMeta, UpdateMeta, ... stuff from the +// interface definition (not the actual functions!). This would make +// the record.Record interface slim and only provide methods that +// most users actually need. All those database/storage related methods +// can still be accessed by using GetBase().XXX() instead. We can also +// expose the dbName and dbKey and meta properties directly which would +// make a nice JSON blob when marshalled. + +// Base provides a quick way to comply with the Model interface. +type Base struct { + dbName string + dbKey string + meta *Meta +} + +// SetKey sets the key on the database record. The key may only be set once and +// future calls to SetKey will be ignored. If you want to copy/move the record +// to another database key, you will need to create a copy and assign a new key. +// A key must be set before the record is used in any database operation. +func (b *Base) SetKey(key string) { + if !b.KeyIsSet() { + b.dbName, b.dbKey = ParseKey(key) + } else { + log.Errorf("database: key is already set: tried to replace %q with %q", b.Key(), key) + } +} + +// ResetKey resets the database name and key. +// Use with caution! +func (b *Base) ResetKey() { + b.dbName = "" + b.dbKey = "" +} + +// Key returns the key of the database record. +// As the key must be set before any usage and can only be set once, this +// function may be used without locking the record. +func (b *Base) Key() string { + return b.dbName + ":" + b.dbKey +} + +// KeyIsSet returns true if the database key is set. +// As the key must be set before any usage and can only be set once, this +// function may be used without locking the record. +func (b *Base) KeyIsSet() bool { + return b.dbName != "" +} + +// DatabaseName returns the name of the database. +// As the key must be set before any usage and can only be set once, this +// function may be used without locking the record. +func (b *Base) DatabaseName() string { + return b.dbName +} + +// DatabaseKey returns the database key of the database record. +// As the key must be set before any usage and can only be set once, this +// function may be used without locking the record. +func (b *Base) DatabaseKey() string { + return b.dbKey +} + +// Meta returns the metadata object for this record. +func (b *Base) Meta() *Meta { + return b.meta +} + +// CreateMeta sets a default metadata object for this record. +func (b *Base) CreateMeta() { + b.meta = &Meta{} +} + +// UpdateMeta creates the metadata if it does not exist and updates it. 
+func (b *Base) UpdateMeta() { + if b.meta == nil { + b.CreateMeta() + } + b.meta.Update() +} + +// SetMeta sets the metadata on the database record, it should only be called after loading the record. Use MoveTo to save the record with another key. +func (b *Base) SetMeta(meta *Meta) { + b.meta = meta +} + +// Marshal marshals the object, without the database key or metadata. It returns nil if the record is deleted. +func (b *Base) Marshal(self Record, format uint8) ([]byte, error) { + if b.Meta() == nil { + return nil, errors.New("missing meta") + } + + if b.Meta().Deleted > 0 { + return nil, nil + } + + dumped, err := dsd.Dump(self, format) + if err != nil { + return nil, err + } + return dumped, nil +} + +// MarshalRecord packs the object, including metadata, into a byte array for saving in a database. +func (b *Base) MarshalRecord(self Record) ([]byte, error) { + if b.Meta() == nil { + return nil, errors.New("missing meta") + } + + // version + c := container.New([]byte{1}) + + // meta encoding + metaSection, err := dsd.Dump(b.meta, dsd.GenCode) + if err != nil { + return nil, err + } + c.AppendAsBlock(metaSection) + + // data + dataSection, err := b.Marshal(self, dsd.JSON) + if err != nil { + return nil, err + } + c.Append(dataSection) + + return c.CompileData(), nil +} + +// IsWrapped returns whether the record is a Wrapper. +func (b *Base) IsWrapped() bool { + return false +} + +// GetAccessor returns an accessor for this record, if available. +func (b *Base) GetAccessor(self Record) accessor.Accessor { + return accessor.NewStructAccessor(self) +} diff --git a/base/database/record/base_test.go b/base/database/record/base_test.go new file mode 100644 index 000000000..2f1521176 --- /dev/null +++ b/base/database/record/base_test.go @@ -0,0 +1,13 @@ +package record + +import "testing" + +func TestBaseRecord(t *testing.T) { + t.Parallel() + + // check model interface compliance + var m Record + b := &TestRecord{} + m = b + _ = m +} diff --git a/base/database/record/key.go b/base/database/record/key.go new file mode 100644 index 000000000..0dfb0a331 --- /dev/null +++ b/base/database/record/key.go @@ -0,0 +1,14 @@ +package record + +import ( + "strings" +) + +// ParseKey splits a key into it's database name and key parts. 
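+//
+// For example, "core:profiles/abc" is split into the database name "core" and
+// the key "profiles/abc"; a key without a ":" yields an empty key part.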
+func ParseKey(key string) (dbName, dbKey string) { + splitted := strings.SplitN(key, ":", 2) + if len(splitted) < 2 { + return splitted[0], "" + } + return splitted[0], strings.Join(splitted[1:], ":") +} diff --git a/base/database/record/meta-bench_test.go b/base/database/record/meta-bench_test.go new file mode 100644 index 000000000..f3c048542 --- /dev/null +++ b/base/database/record/meta-bench_test.go @@ -0,0 +1,348 @@ +package record + +// Benchmark: +// BenchmarkAllocateBytes-8 2000000000 0.76 ns/op +// BenchmarkAllocateStruct1-8 2000000000 0.76 ns/op +// BenchmarkAllocateStruct2-8 2000000000 0.79 ns/op +// BenchmarkMetaSerializeContainer-8 1000000 1703 ns/op +// BenchmarkMetaUnserializeContainer-8 2000000 950 ns/op +// BenchmarkMetaSerializeVarInt-8 3000000 457 ns/op +// BenchmarkMetaUnserializeVarInt-8 20000000 62.9 ns/op +// BenchmarkMetaSerializeWithXDR2-8 1000000 2360 ns/op +// BenchmarkMetaUnserializeWithXDR2-8 500000 3189 ns/op +// BenchmarkMetaSerializeWithColfer-8 10000000 237 ns/op +// BenchmarkMetaUnserializeWithColfer-8 20000000 51.7 ns/op +// BenchmarkMetaSerializeWithCodegen-8 50000000 23.7 ns/op +// BenchmarkMetaUnserializeWithCodegen-8 100000000 18.9 ns/op +// BenchmarkMetaSerializeWithDSDJSON-8 1000000 2398 ns/op +// BenchmarkMetaUnserializeWithDSDJSON-8 300000 6264 ns/op + +import ( + "testing" + "time" + + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/formats/varint" +) + +var testMeta = &Meta{ + Created: time.Now().Unix(), + Modified: time.Now().Unix(), + Expires: time.Now().Unix(), + Deleted: time.Now().Unix(), + secret: true, + cronjewel: true, +} + +func BenchmarkAllocateBytes(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = make([]byte, 33) + } +} + +func BenchmarkAllocateStruct1(b *testing.B) { + for i := 0; i < b.N; i++ { + var newMeta Meta + _ = newMeta + } +} + +func BenchmarkAllocateStruct2(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = Meta{} + } +} + +func BenchmarkMetaSerializeContainer(b *testing.B) { + // Start benchmark + for i := 0; i < b.N; i++ { + c := container.New() + c.AppendNumber(uint64(testMeta.Created)) + c.AppendNumber(uint64(testMeta.Modified)) + c.AppendNumber(uint64(testMeta.Expires)) + c.AppendNumber(uint64(testMeta.Deleted)) + switch { + case testMeta.secret && testMeta.cronjewel: + c.AppendNumber(3) + case testMeta.secret: + c.AppendNumber(1) + case testMeta.cronjewel: + c.AppendNumber(2) + default: + c.AppendNumber(0) + } + } +} + +func BenchmarkMetaUnserializeContainer(b *testing.B) { + // Setup + c := container.New() + c.AppendNumber(uint64(testMeta.Created)) + c.AppendNumber(uint64(testMeta.Modified)) + c.AppendNumber(uint64(testMeta.Expires)) + c.AppendNumber(uint64(testMeta.Deleted)) + switch { + case testMeta.secret && testMeta.cronjewel: + c.AppendNumber(3) + case testMeta.secret: + c.AppendNumber(1) + case testMeta.cronjewel: + c.AppendNumber(2) + default: + c.AppendNumber(0) + } + encodedData := c.CompileData() + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + var newMeta Meta + var err error + var num uint64 + c := container.New(encodedData) + num, err = c.GetNextN64() + newMeta.Created = int64(num) + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + num, err = c.GetNextN64() + newMeta.Modified = int64(num) + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + num, err = c.GetNextN64() + newMeta.Expires = int64(num) + if err != nil { 
+ b.Errorf("could not decode: %s", err) + return + } + num, err = c.GetNextN64() + newMeta.Deleted = int64(num) + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + + flags, err := c.GetNextN8() + if err != nil { + b.Errorf("could not decode: %s", err) + return + } + + switch flags { + case 3: + newMeta.secret = true + newMeta.cronjewel = true + case 2: + newMeta.cronjewel = true + case 1: + newMeta.secret = true + case 0: + default: + b.Errorf("invalid flag value: %d", flags) + return + } + } +} + +func BenchmarkMetaSerializeVarInt(b *testing.B) { + // Start benchmark + for i := 0; i < b.N; i++ { + encoded := make([]byte, 33) + offset := 0 + data := varint.Pack64(uint64(testMeta.Created)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Modified)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Expires)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Deleted)) + for _, part := range data { + encoded[offset] = part + offset++ + } + + switch { + case testMeta.secret && testMeta.cronjewel: + encoded[offset] = 3 + case testMeta.secret: + encoded[offset] = 1 + case testMeta.cronjewel: + encoded[offset] = 2 + default: + encoded[offset] = 0 + } + } +} + +func BenchmarkMetaUnserializeVarInt(b *testing.B) { + // Setup + encoded := make([]byte, 33) + offset := 0 + data := varint.Pack64(uint64(testMeta.Created)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Modified)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Expires)) + for _, part := range data { + encoded[offset] = part + offset++ + } + data = varint.Pack64(uint64(testMeta.Deleted)) + for _, part := range data { + encoded[offset] = part + offset++ + } + + switch { + case testMeta.secret && testMeta.cronjewel: + encoded[offset] = 3 + case testMeta.secret: + encoded[offset] = 1 + case testMeta.cronjewel: + encoded[offset] = 2 + default: + encoded[offset] = 0 + } + offset++ + encodedData := encoded[:offset] + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + var newMeta Meta + offset = 0 + + num, n, err := varint.Unpack64(encodedData) + if err != nil { + b.Error(err) + return + } + testMeta.Created = int64(num) + offset += n + + num, n, err = varint.Unpack64(encodedData[offset:]) + if err != nil { + b.Error(err) + return + } + testMeta.Modified = int64(num) + offset += n + + num, n, err = varint.Unpack64(encodedData[offset:]) + if err != nil { + b.Error(err) + return + } + testMeta.Expires = int64(num) + offset += n + + num, n, err = varint.Unpack64(encodedData[offset:]) + if err != nil { + b.Error(err) + return + } + testMeta.Deleted = int64(num) + offset += n + + switch encodedData[offset] { + case 3: + newMeta.secret = true + newMeta.cronjewel = true + case 2: + newMeta.cronjewel = true + case 1: + newMeta.secret = true + case 0: + default: + b.Errorf("invalid flag value: %d", encodedData[offset]) + return + } + } +} + +func BenchmarkMetaSerializeWithCodegen(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := testMeta.GenCodeMarshal(nil) + if err != nil { + b.Errorf("failed to serialize with codegen: %s", err) + return + } + } +} + +func BenchmarkMetaUnserializeWithCodegen(b *testing.B) { + // Setup + encodedData, err := testMeta.GenCodeMarshal(nil) + if err 
!= nil { + b.Errorf("failed to serialize with codegen: %s", err) + return + } + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + var newMeta Meta + _, err := newMeta.GenCodeUnmarshal(encodedData) + if err != nil { + b.Errorf("failed to unserialize with codegen: %s", err) + return + } + } +} + +func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := dsd.Dump(testMeta, dsd.JSON) + if err != nil { + b.Errorf("failed to serialize with DSD/JSON: %s", err) + return + } + } +} + +func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) { + // Setup + encodedData, err := dsd.Dump(testMeta, dsd.JSON) + if err != nil { + b.Errorf("failed to serialize with DSD/JSON: %s", err) + return + } + + // Reset timer for precise results + b.ResetTimer() + + // Start benchmark + for i := 0; i < b.N; i++ { + var newMeta Meta + _, err := dsd.Load(encodedData, &newMeta) + if err != nil { + b.Errorf("failed to unserialize with DSD/JSON: %s", err) + return + } + } +} diff --git a/base/database/record/meta-gencode.go b/base/database/record/meta-gencode.go new file mode 100644 index 000000000..180e7fa92 --- /dev/null +++ b/base/database/record/meta-gencode.go @@ -0,0 +1,145 @@ +package record + +import ( + "fmt" +) + +// GenCodeSize returns the size of the gencode marshalled byte slice. +func (m *Meta) GenCodeSize() (s int) { + s += 34 + return +} + +// GenCodeMarshal gencode marshalls Meta into the given byte array, or a new one if its too small. +func (m *Meta) GenCodeMarshal(buf []byte) ([]byte, error) { + size := m.GenCodeSize() + { + if cap(buf) >= size { + buf = buf[:size] + } else { + buf = make([]byte, size) + } + } + i := uint64(0) + + { + + buf[0+0] = byte(m.Created >> 0) + + buf[1+0] = byte(m.Created >> 8) + + buf[2+0] = byte(m.Created >> 16) + + buf[3+0] = byte(m.Created >> 24) + + buf[4+0] = byte(m.Created >> 32) + + buf[5+0] = byte(m.Created >> 40) + + buf[6+0] = byte(m.Created >> 48) + + buf[7+0] = byte(m.Created >> 56) + + } + { + + buf[0+8] = byte(m.Modified >> 0) + + buf[1+8] = byte(m.Modified >> 8) + + buf[2+8] = byte(m.Modified >> 16) + + buf[3+8] = byte(m.Modified >> 24) + + buf[4+8] = byte(m.Modified >> 32) + + buf[5+8] = byte(m.Modified >> 40) + + buf[6+8] = byte(m.Modified >> 48) + + buf[7+8] = byte(m.Modified >> 56) + + } + { + + buf[0+16] = byte(m.Expires >> 0) + + buf[1+16] = byte(m.Expires >> 8) + + buf[2+16] = byte(m.Expires >> 16) + + buf[3+16] = byte(m.Expires >> 24) + + buf[4+16] = byte(m.Expires >> 32) + + buf[5+16] = byte(m.Expires >> 40) + + buf[6+16] = byte(m.Expires >> 48) + + buf[7+16] = byte(m.Expires >> 56) + + } + { + + buf[0+24] = byte(m.Deleted >> 0) + + buf[1+24] = byte(m.Deleted >> 8) + + buf[2+24] = byte(m.Deleted >> 16) + + buf[3+24] = byte(m.Deleted >> 24) + + buf[4+24] = byte(m.Deleted >> 32) + + buf[5+24] = byte(m.Deleted >> 40) + + buf[6+24] = byte(m.Deleted >> 48) + + buf[7+24] = byte(m.Deleted >> 56) + + } + { + if m.secret { + buf[32] = 1 + } else { + buf[32] = 0 + } + } + { + if m.cronjewel { + buf[33] = 1 + } else { + buf[33] = 0 + } + } + return buf[:i+34], nil +} + +// GenCodeUnmarshal gencode unmarshalls Meta and returns the bytes read. 
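// The buffer passed in is expected to match the layout produced by
// GenCodeMarshal: bytes 0-7, 8-15, 16-23 and 24-31 hold Created, Modified,
// Expires and Deleted as little-endian int64 values, byte 32 holds the
// secret flag and byte 33 the crownjewel flag (0 or 1 each), 34 bytes in
// total.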
+func (m *Meta) GenCodeUnmarshal(buf []byte) (uint64, error) { + if len(buf) < m.GenCodeSize() { + return 0, fmt.Errorf("insufficient data: got %d out of %d bytes", len(buf), m.GenCodeSize()) + } + + i := uint64(0) + + { + m.Created = 0 | (int64(buf[0+0]) << 0) | (int64(buf[1+0]) << 8) | (int64(buf[2+0]) << 16) | (int64(buf[3+0]) << 24) | (int64(buf[4+0]) << 32) | (int64(buf[5+0]) << 40) | (int64(buf[6+0]) << 48) | (int64(buf[7+0]) << 56) + } + { + m.Modified = 0 | (int64(buf[0+8]) << 0) | (int64(buf[1+8]) << 8) | (int64(buf[2+8]) << 16) | (int64(buf[3+8]) << 24) | (int64(buf[4+8]) << 32) | (int64(buf[5+8]) << 40) | (int64(buf[6+8]) << 48) | (int64(buf[7+8]) << 56) + } + { + m.Expires = 0 | (int64(buf[0+16]) << 0) | (int64(buf[1+16]) << 8) | (int64(buf[2+16]) << 16) | (int64(buf[3+16]) << 24) | (int64(buf[4+16]) << 32) | (int64(buf[5+16]) << 40) | (int64(buf[6+16]) << 48) | (int64(buf[7+16]) << 56) + } + { + m.Deleted = 0 | (int64(buf[0+24]) << 0) | (int64(buf[1+24]) << 8) | (int64(buf[2+24]) << 16) | (int64(buf[3+24]) << 24) | (int64(buf[4+24]) << 32) | (int64(buf[5+24]) << 40) | (int64(buf[6+24]) << 48) | (int64(buf[7+24]) << 56) + } + { + m.secret = buf[32] == 1 + } + { + m.cronjewel = buf[33] == 1 + } + return i + 34, nil +} diff --git a/base/database/record/meta-gencode_test.go b/base/database/record/meta-gencode_test.go new file mode 100644 index 000000000..2de765ca6 --- /dev/null +++ b/base/database/record/meta-gencode_test.go @@ -0,0 +1,35 @@ +package record + +import ( + "reflect" + "testing" + "time" +) + +var genCodeTestMeta = &Meta{ + Created: time.Now().Unix(), + Modified: time.Now().Unix(), + Expires: time.Now().Unix(), + Deleted: time.Now().Unix(), + secret: true, + cronjewel: true, +} + +func TestGenCode(t *testing.T) { + t.Parallel() + + encoded, err := genCodeTestMeta.GenCodeMarshal(nil) + if err != nil { + t.Fatal(err) + } + + newMeta := &Meta{} + _, err = newMeta.GenCodeUnmarshal(encoded) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(genCodeTestMeta, newMeta) { + t.Errorf("objects are not equal, got: %v", newMeta) + } +} diff --git a/base/database/record/meta.colf b/base/database/record/meta.colf new file mode 100644 index 000000000..0072e92f8 --- /dev/null +++ b/base/database/record/meta.colf @@ -0,0 +1,10 @@ +package record + +type course struct { + Created int64 + Modified int64 + Expires int64 + Deleted int64 + Secret bool + Cronjewel bool +} diff --git a/base/database/record/meta.gencode b/base/database/record/meta.gencode new file mode 100644 index 000000000..7592e2d88 --- /dev/null +++ b/base/database/record/meta.gencode @@ -0,0 +1,8 @@ +struct Meta { + Created int64 + Modified int64 + Expires int64 + Deleted int64 + Secret bool + Cronjewel bool +} diff --git a/base/database/record/meta.go b/base/database/record/meta.go new file mode 100644 index 000000000..54a0e6148 --- /dev/null +++ b/base/database/record/meta.go @@ -0,0 +1,129 @@ +package record + +import "time" + +// Meta holds metadata about the record. +type Meta struct { + Created int64 + Modified int64 + Expires int64 + Deleted int64 + secret bool // secrets must not be sent to the UI, only synced between nodes + cronjewel bool // crownjewels must never leave the instance, but may be read by the UI +} + +// SetAbsoluteExpiry sets an absolute expiry time (in seconds), that is not affected when the record is updated. +func (m *Meta) SetAbsoluteExpiry(seconds int64) { + m.Expires = seconds + m.Deleted = 0 +} + +// SetRelativateExpiry sets a relative expiry time (ie. 
TTL in seconds) that is automatically updated whenever the record is updated/saved. +func (m *Meta) SetRelativateExpiry(seconds int64) { + if seconds >= 0 { + m.Deleted = -seconds + } +} + +// GetAbsoluteExpiry returns the absolute expiry time. +func (m *Meta) GetAbsoluteExpiry() int64 { + return m.Expires +} + +// GetRelativeExpiry returns the current relative expiry time - ie. seconds until expiry. +// A negative value signifies that the record does not expire. +func (m *Meta) GetRelativeExpiry() int64 { + if m.Expires == 0 { + return -1 + } + + abs := m.Expires - time.Now().Unix() + if abs < 0 { + return 0 + } + return abs +} + +// MakeCrownJewel marks the database records as a crownjewel, meaning that it will not be sent/synced to other devices. +func (m *Meta) MakeCrownJewel() { + m.cronjewel = true +} + +// MakeSecret sets the database record as secret, meaning that it may only be used internally, and not by interfacing processes, such as the UI. +func (m *Meta) MakeSecret() { + m.secret = true +} + +// Update updates the internal meta states and should be called before writing the record to the database. +func (m *Meta) Update() { + now := time.Now().Unix() + m.Modified = now + if m.Created == 0 { + m.Created = now + } + if m.Deleted < 0 { + m.Expires = now - m.Deleted + } +} + +// Reset resets all metadata, except for the secret and crownjewel status. +func (m *Meta) Reset() { + m.Created = 0 + m.Modified = 0 + m.Expires = 0 + m.Deleted = 0 +} + +// Delete marks the record as deleted. +func (m *Meta) Delete() { + m.Deleted = time.Now().Unix() +} + +// IsDeleted returns whether the record is deleted. +func (m *Meta) IsDeleted() bool { + return m.Deleted > 0 +} + +// CheckValidity checks whether the database record is valid. +func (m *Meta) CheckValidity() (valid bool) { + if m == nil { + return false + } + + switch { + case m.Deleted > 0: + return false + case m.Expires > 0 && m.Expires < time.Now().Unix(): + return false + default: + return true + } +} + +// CheckPermission checks whether the database record may be accessed with the following scope. +func (m *Meta) CheckPermission(local, internal bool) (permitted bool) { + if m == nil { + return false + } + + switch { + case !local && m.cronjewel: + return false + case !internal && m.secret: + return false + default: + return true + } +} + +// Duplicate returns a new copy of Meta. +func (m *Meta) Duplicate() *Meta { + return &Meta{ + Created: m.Created, + Modified: m.Modified, + Expires: m.Expires, + Deleted: m.Deleted, + secret: m.secret, + cronjewel: m.cronjewel, + } +} diff --git a/base/database/record/record.go b/base/database/record/record.go new file mode 100644 index 000000000..f18dc8985 --- /dev/null +++ b/base/database/record/record.go @@ -0,0 +1,32 @@ +package record + +import ( + "github.com/safing/portmaster/base/database/accessor" +) + +// Record provides an interface for uniformally handling database records. +type Record interface { + SetKey(key string) // test:config + Key() string // test:config + KeyIsSet() bool + DatabaseName() string // test + DatabaseKey() string // config + + // Metadata. + Meta() *Meta + SetMeta(meta *Meta) + CreateMeta() + UpdateMeta() + + // Serialization. + Marshal(self Record, format uint8) ([]byte, error) + MarshalRecord(self Record) ([]byte, error) + GetAccessor(self Record) accessor.Accessor + + // Locking. + Lock() + Unlock() + + // Wrapping. 
+ IsWrapped() bool +} diff --git a/base/database/record/record_test.go b/base/database/record/record_test.go new file mode 100644 index 000000000..5912d6796 --- /dev/null +++ b/base/database/record/record_test.go @@ -0,0 +1,10 @@ +package record + +import ( + "sync" +) + +type TestRecord struct { + Base + sync.Mutex +} diff --git a/base/database/record/wrapper.go b/base/database/record/wrapper.go new file mode 100644 index 000000000..b8505baff --- /dev/null +++ b/base/database/record/wrapper.go @@ -0,0 +1,160 @@ +package record + +import ( + "errors" + "fmt" + "sync" + + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/database/accessor" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/formats/varint" +) + +// Wrapper wraps raw data and implements the Record interface. +type Wrapper struct { + Base + sync.Mutex + + Format uint8 + Data []byte +} + +// NewRawWrapper returns a record wrapper for the given data, including metadata. This is normally only used by storage backends when loading records. +func NewRawWrapper(database, key string, data []byte) (*Wrapper, error) { + version, offset, err := varint.Unpack8(data) + if err != nil { + return nil, err + } + if version != 1 { + return nil, fmt.Errorf("incompatible record version: %d", version) + } + + metaSection, n, err := varint.GetNextBlock(data[offset:]) + if err != nil { + return nil, fmt.Errorf("could not get meta section: %w", err) + } + offset += n + + newMeta := &Meta{} + _, err = dsd.Load(metaSection, newMeta) + if err != nil { + return nil, fmt.Errorf("could not unmarshal meta section: %w", err) + } + + var format uint8 = dsd.RAW + if !newMeta.IsDeleted() { + format, n, err = varint.Unpack8(data[offset:]) + if err != nil { + return nil, fmt.Errorf("could not get dsd format: %w", err) + } + offset += n + } + + return &Wrapper{ + Base{ + database, + key, + newMeta, + }, + sync.Mutex{}, + format, + data[offset:], + }, nil +} + +// NewWrapper returns a new record wrapper for the given data. +func NewWrapper(key string, meta *Meta, format uint8, data []byte) (*Wrapper, error) { + dbName, dbKey := ParseKey(key) + + return &Wrapper{ + Base{ + dbName: dbName, + dbKey: dbKey, + meta: meta, + }, + sync.Mutex{}, + format, + data, + }, nil +} + +// Marshal marshals the object, without the database key or metadata. +func (w *Wrapper) Marshal(r Record, format uint8) ([]byte, error) { + if w.Meta() == nil { + return nil, errors.New("missing meta") + } + + if w.Meta().Deleted > 0 { + return nil, nil + } + + if format != dsd.AUTO && format != w.Format { + return nil, errors.New("could not dump model, wrapped object format mismatch") + } + + data := make([]byte, len(w.Data)+1) + data[0] = w.Format + copy(data[1:], w.Data) + + return data, nil +} + +// MarshalRecord packs the object, including metadata, into a byte array for saving in a database. +func (w *Wrapper) MarshalRecord(r Record) ([]byte, error) { + // Duplication necessary, as the version from Base would call Base.Marshal instead of Wrapper.Marshal + + if w.Meta() == nil { + return nil, errors.New("missing meta") + } + + // version + c := container.New([]byte{1}) + + // meta + metaSection, err := dsd.Dump(w.meta, dsd.GenCode) + if err != nil { + return nil, err + } + c.AppendAsBlock(metaSection) + + // data + dataSection, err := w.Marshal(r, dsd.AUTO) + if err != nil { + return nil, err + } + c.Append(dataSection) + + return c.CompileData(), nil +} + +// IsWrapped returns whether the record is a Wrapper. 
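// For reference, a raw record as produced by Wrapper.MarshalRecord and
// parsed by NewRawWrapper above is laid out as follows (sketch derived from
// those two functions):
//
//	[varint: record version, currently 1]
//	[varint-length-prefixed block: DSD-encoded Meta section]
//	[1 byte: DSD format of the data, omitted for deleted records]
//	[data ...]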
+func (w *Wrapper) IsWrapped() bool { + return true +} + +// Unwrap unwraps data into a record. +func Unwrap(wrapped, r Record) error { + wrapper, ok := wrapped.(*Wrapper) + if !ok { + return fmt.Errorf("cannot unwrap %T", wrapped) + } + + err := dsd.LoadAsFormat(wrapper.Data, wrapper.Format, r) + if err != nil { + return fmt.Errorf("failed to unwrap %T: %w", r, err) + } + + r.SetKey(wrapped.Key()) + r.SetMeta(wrapped.Meta()) + + return nil +} + +// GetAccessor returns an accessor for this record, if available. +func (w *Wrapper) GetAccessor(self Record) accessor.Accessor { + if w.Format == dsd.JSON && len(w.Data) > 0 { + return accessor.NewJSONBytesAccessor(&w.Data) + } + return nil +} diff --git a/base/database/record/wrapper_test.go b/base/database/record/wrapper_test.go new file mode 100644 index 000000000..5db3b01d3 --- /dev/null +++ b/base/database/record/wrapper_test.go @@ -0,0 +1,57 @@ +package record + +import ( + "bytes" + "testing" + + "github.com/safing/portmaster/base/formats/dsd" +) + +func TestWrapper(t *testing.T) { + t.Parallel() + + // check model interface compliance + var m Record + w := &Wrapper{} + m = w + _ = m + + // create test data + testData := []byte(`{"a": "b"}`) + encodedTestData := []byte(`J{"a": "b"}`) + + // test wrapper + wrapper, err := NewWrapper("test:a", &Meta{}, dsd.JSON, testData) + if err != nil { + t.Fatal(err) + } + if wrapper.Format != dsd.JSON { + t.Error("format mismatch") + } + if !bytes.Equal(testData, wrapper.Data) { + t.Error("data mismatch") + } + + encoded, err := wrapper.Marshal(wrapper, dsd.JSON) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(encodedTestData, encoded) { + t.Error("marshal mismatch") + } + + wrapper.SetMeta(&Meta{}) + wrapper.meta.Update() + raw, err := wrapper.MarshalRecord(wrapper) + if err != nil { + t.Fatal(err) + } + + wrapper2, err := NewRawWrapper("test", "a", raw) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(testData, wrapper2.Data) { + t.Error("marshal mismatch") + } +} diff --git a/base/database/registry.go b/base/database/registry.go new file mode 100644 index 000000000..44dcbae73 --- /dev/null +++ b/base/database/registry.go @@ -0,0 +1,168 @@ +package database + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "path" + "regexp" + "sync" + "time" + + "github.com/tevino/abool" +) + +const ( + registryFileName = "databases.json" +) + +var ( + registryPersistence = abool.NewBool(false) + writeRegistrySoon = abool.NewBool(false) + + registry = make(map[string]*Database) + registryLock sync.Mutex + + nameConstraint = regexp.MustCompile("^[A-Za-z0-9_-]{3,}$") +) + +// Register registers a new database. +// If the database is already registered, only +// the description and the primary API will be +// updated and the effective object will be returned. 
+func Register(db *Database) (*Database, error) { + if !initialized.IsSet() { + return nil, errors.New("database not initialized") + } + + registryLock.Lock() + defer registryLock.Unlock() + + registeredDB, ok := registry[db.Name] + save := false + + if ok { + // update database + if registeredDB.Description != db.Description { + registeredDB.Description = db.Description + save = true + } + if registeredDB.ShadowDelete != db.ShadowDelete { + registeredDB.ShadowDelete = db.ShadowDelete + save = true + } + } else { + // register new database + if !nameConstraint.MatchString(db.Name) { + return nil, errors.New("database name must only contain alphanumeric and `_-` characters and must be at least 3 characters long") + } + + now := time.Now().Round(time.Second) + db.Registered = now + db.LastUpdated = now + db.LastLoaded = time.Time{} + + registry[db.Name] = db + save = true + } + + if save && registryPersistence.IsSet() { + if ok { + registeredDB.Updated() + } + err := saveRegistry(false) + if err != nil { + return nil, err + } + } + + if ok { + return registeredDB, nil + } + return nil, nil +} + +func getDatabase(name string) (*Database, error) { + registryLock.Lock() + defer registryLock.Unlock() + + registeredDB, ok := registry[name] + if !ok { + return nil, fmt.Errorf(`database "%s" not registered`, name) + } + if time.Now().Add(-24 * time.Hour).After(registeredDB.LastLoaded) { + writeRegistrySoon.Set() + } + registeredDB.Loaded() + + return registeredDB, nil +} + +// EnableRegistryPersistence enables persistence of the database registry. +func EnableRegistryPersistence() { + if registryPersistence.SetToIf(false, true) { + // start registry writer + go registryWriter() + // TODO: make an initial write if database system is already initialized + } +} + +func loadRegistry() error { + registryLock.Lock() + defer registryLock.Unlock() + + // read file + filePath := path.Join(rootStructure.Path, registryFileName) + data, err := os.ReadFile(filePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err + } + + // parse + databases := make(map[string]*Database) + err = json.Unmarshal(data, &databases) + if err != nil { + return err + } + + // set + registry = databases + return nil +} + +func saveRegistry(lock bool) error { + if lock { + registryLock.Lock() + defer registryLock.Unlock() + } + + // marshal + data, err := json.MarshalIndent(registry, "", "\t") + if err != nil { + return err + } + + // write file + // TODO: write atomically (best effort) + filePath := path.Join(rootStructure.Path, registryFileName) + return os.WriteFile(filePath, data, 0o0600) +} + +func registryWriter() { + for { + select { + case <-time.After(1 * time.Hour): + if writeRegistrySoon.SetToIf(true, false) { + _ = saveRegistry(true) + } + case <-shutdownSignal: + _ = saveRegistry(true) + return + } + } +} diff --git a/base/database/storage/badger/badger.go b/base/database/storage/badger/badger.go new file mode 100644 index 000000000..fbdc44a8e --- /dev/null +++ b/base/database/storage/badger/badger.go @@ -0,0 +1,231 @@ +package badger + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/dgraph-io/badger" + + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" + "github.com/safing/portmaster/base/log" +) + +// Badger database made pluggable for portbase. 
+type Badger struct { + name string + db *badger.DB +} + +func init() { + _ = storage.Register("badger", NewBadger) +} + +// NewBadger opens/creates a badger database. +func NewBadger(name, location string) (storage.Interface, error) { + opts := badger.DefaultOptions(location) + + db, err := badger.Open(opts) + if errors.Is(err, badger.ErrTruncateNeeded) { + // clean up after crash + log.Warningf("database/storage: truncating corrupted value log of badger database %s: this may cause data loss", name) + opts.Truncate = true + db, err = badger.Open(opts) + } + if err != nil { + return nil, err + } + + return &Badger{ + name: name, + db: db, + }, nil +} + +// Get returns a database record. +func (b *Badger) Get(key string) (record.Record, error) { + var item *badger.Item + + err := b.db.View(func(txn *badger.Txn) error { + var err error + item, err = txn.Get([]byte(key)) + if err != nil { + if errors.Is(err, badger.ErrKeyNotFound) { + return storage.ErrNotFound + } + return err + } + return nil + }) + if err != nil { + return nil, err + } + + // return err if deleted or expired + if item.IsDeletedOrExpired() { + return nil, storage.ErrNotFound + } + + data, err := item.ValueCopy(nil) + if err != nil { + return nil, err + } + + m, err := record.NewRawWrapper(b.name, string(item.Key()), data) + if err != nil { + return nil, err + } + return m, nil +} + +// GetMeta returns the metadata of a database record. +func (b *Badger) GetMeta(key string) (*record.Meta, error) { + // TODO: Replace with more performant variant. + + r, err := b.Get(key) + if err != nil { + return nil, err + } + + return r.Meta(), nil +} + +// Put stores a record in the database. +func (b *Badger) Put(r record.Record) (record.Record, error) { + data, err := r.MarshalRecord(r) + if err != nil { + return nil, err + } + + err = b.db.Update(func(txn *badger.Txn) error { + return txn.Set([]byte(r.DatabaseKey()), data) + }) + if err != nil { + return nil, err + } + return r, nil +} + +// Delete deletes a record from the database. +func (b *Badger) Delete(key string) error { + return b.db.Update(func(txn *badger.Txn) error { + err := txn.Delete([]byte(key)) + if err != nil && !errors.Is(err, badger.ErrKeyNotFound) { + return err + } + return nil + }) +} + +// Query returns a an iterator for the supplied query. 
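// The returned iterator is consumed by ranging over its Next channel and
// checking Err afterwards, as the tests in this patch do. Illustrative
// sketch, with db and q assumed to be set up already:
//
//	it, err := db.Query(q, true, true)
//	if err != nil {
//		return err
//	}
//	for r := range it.Next {
//		// handle record r
//	}
//	if it.Err() != nil {
//		return it.Err()
//	}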
+func (b *Badger) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + _, err := q.Check() + if err != nil { + return nil, fmt.Errorf("invalid query: %w", err) + } + + queryIter := iterator.New() + + go b.queryExecutor(queryIter, q, local, internal) + return queryIter, nil +} + +//nolint:gocognit +func (b *Badger) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) { + err := b.db.View(func(txn *badger.Txn) error { + it := txn.NewIterator(badger.DefaultIteratorOptions) + defer it.Close() + prefix := []byte(q.DatabaseKeyPrefix()) + for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { + item := it.Item() + + var data []byte + err := item.Value(func(val []byte) error { + data = val + return nil + }) + if err != nil { + return err + } + + r, err := record.NewRawWrapper(b.name, string(item.Key()), data) + if err != nil { + return err + } + + if !r.Meta().CheckValidity() { + continue + } + if !r.Meta().CheckPermission(local, internal) { + continue + } + + if q.MatchesRecord(r) { + copiedData, err := item.ValueCopy(nil) + if err != nil { + return err + } + newWrapper, err := record.NewRawWrapper(b.name, r.DatabaseKey(), copiedData) + if err != nil { + return err + } + select { + case <-queryIter.Done: + return nil + case queryIter.Next <- newWrapper: + default: + select { + case queryIter.Next <- newWrapper: + case <-queryIter.Done: + return nil + case <-time.After(1 * time.Minute): + return errors.New("query timeout") + } + } + } + + } + return nil + }) + + queryIter.Finish(err) +} + +// ReadOnly returns whether the database is read only. +func (b *Badger) ReadOnly() bool { + return false +} + +// Injected returns whether the database is injected. +func (b *Badger) Injected() bool { + return false +} + +// Maintain runs a light maintenance operation on the database. +func (b *Badger) Maintain(_ context.Context) error { + _ = b.db.RunValueLogGC(0.7) + return nil +} + +// MaintainThorough runs a thorough maintenance operation on the database. +func (b *Badger) MaintainThorough(_ context.Context) (err error) { + for err == nil { + err = b.db.RunValueLogGC(0.7) + } + return nil +} + +// MaintainRecordStates maintains records states in the database. +func (b *Badger) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error { + // TODO: implement MaintainRecordStates + return nil +} + +// Shutdown shuts down the database. +func (b *Badger) Shutdown() error { + return b.db.Close() +} diff --git a/base/database/storage/badger/badger_test.go b/base/database/storage/badger/badger_test.go new file mode 100644 index 000000000..7814f95a5 --- /dev/null +++ b/base/database/storage/badger/badger_test.go @@ -0,0 +1,148 @@ +package badger + +import ( + "context" + "os" + "reflect" + "sync" + "testing" + + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" +) + +var ( + // Compile time interface checks. 
+ _ storage.Interface = &Badger{} + _ storage.Maintainer = &Badger{} +) + +type TestRecord struct { //nolint:maligned + record.Base + sync.Mutex + S string + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + F32 float32 + F64 float64 + B bool +} + +func TestBadger(t *testing.T) { + t.Parallel() + + testDir, err := os.MkdirTemp("", "testing-") + if err != nil { + t.Fatal(err) + } + defer func() { + _ = os.RemoveAll(testDir) // clean up + }() + + // start + db, err := NewBadger("test", testDir) + if err != nil { + t.Fatal(err) + } + + a := &TestRecord{ + S: "banana", + I: 42, + I8: 42, + I16: 42, + I32: 42, + I64: 42, + UI: 42, + UI8: 42, + UI16: 42, + UI32: 42, + UI64: 42, + F32: 42.42, + F64: 42.42, + B: true, + } + a.SetMeta(&record.Meta{}) + a.Meta().Update() + a.SetKey("test:A") + + // put record + _, err = db.Put(a) + if err != nil { + t.Fatal(err) + } + + // get and compare + r1, err := db.Get("A") + if err != nil { + t.Fatal(err) + } + + a1 := &TestRecord{} + err = record.Unwrap(r1, a1) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(a, a1) { + t.Fatalf("mismatch, got %v", a1) + } + + // test query + q := query.New("").MustBeValid() + it, err := db.Query(q, true, true) + if err != nil { + t.Fatal(err) + } + cnt := 0 + for range it.Next { + cnt++ + } + if it.Err() != nil { + t.Fatal(err) + } + if cnt != 1 { + t.Fatalf("unexpected query result count: %d", cnt) + } + + // delete + err = db.Delete("A") + if err != nil { + t.Fatal(err) + } + + // check if its gone + _, err = db.Get("A") + if err == nil { + t.Fatal("should fail") + } + + // maintenance + maintainer, ok := db.(storage.Maintainer) + if ok { + err = maintainer.Maintain(context.TODO()) + if err != nil { + t.Fatal(err) + } + err = maintainer.MaintainThorough(context.TODO()) + if err != nil { + t.Fatal(err) + } + } else { + t.Fatal("should implement Maintainer") + } + + // shutdown + err = db.Shutdown() + if err != nil { + t.Fatal(err) + } +} diff --git a/base/database/storage/bbolt/bbolt.go b/base/database/storage/bbolt/bbolt.go new file mode 100644 index 000000000..3c93309f5 --- /dev/null +++ b/base/database/storage/bbolt/bbolt.go @@ -0,0 +1,427 @@ +package bbolt + +import ( + "bytes" + "context" + "errors" + "fmt" + "path/filepath" + "time" + + "go.etcd.io/bbolt" + + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" +) + +var bucketName = []byte{0} + +// BBolt database made pluggable for portbase. +type BBolt struct { + name string + db *bbolt.DB +} + +func init() { + _ = storage.Register("bbolt", NewBBolt) +} + +// NewBBolt opens/creates a bbolt database. +func NewBBolt(name, location string) (storage.Interface, error) { + // Create options for bbolt database. + dbFile := filepath.Join(location, "db.bbolt") + dbOptions := &bbolt.Options{ + Timeout: 1 * time.Second, + } + + // Open/Create database, retry if there is a timeout. + db, err := bbolt.Open(dbFile, 0o0600, dbOptions) + for i := 0; i < 5 && err != nil; i++ { + // Try again if there is an error. 
+ db, err = bbolt.Open(dbFile, 0o0600, dbOptions) + } + if err != nil { + return nil, err + } + + // Create bucket + err = db.Update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists(bucketName) + if err != nil { + return err + } + return nil + }) + if err != nil { + return nil, err + } + + return &BBolt{ + name: name, + db: db, + }, nil +} + +// Get returns a database record. +func (b *BBolt) Get(key string) (record.Record, error) { + var r record.Record + + err := b.db.View(func(tx *bbolt.Tx) error { + // get value from db + value := tx.Bucket(bucketName).Get([]byte(key)) + if value == nil { + return storage.ErrNotFound + } + + // copy data + duplicate := make([]byte, len(value)) + copy(duplicate, value) + + // create record + var txErr error + r, txErr = record.NewRawWrapper(b.name, key, duplicate) + if txErr != nil { + return txErr + } + return nil + }) + if err != nil { + return nil, err + } + return r, nil +} + +// GetMeta returns the metadata of a database record. +func (b *BBolt) GetMeta(key string) (*record.Meta, error) { + // TODO: Replace with more performant variant. + + r, err := b.Get(key) + if err != nil { + return nil, err + } + + return r.Meta(), nil +} + +// Put stores a record in the database. +func (b *BBolt) Put(r record.Record) (record.Record, error) { + data, err := r.MarshalRecord(r) + if err != nil { + return nil, err + } + + err = b.db.Update(func(tx *bbolt.Tx) error { + txErr := tx.Bucket(bucketName).Put([]byte(r.DatabaseKey()), data) + if txErr != nil { + return txErr + } + return nil + }) + if err != nil { + return nil, err + } + return r, nil +} + +// PutMany stores many records in the database. +func (b *BBolt) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error) { + batch := make(chan record.Record, 100) + errs := make(chan error, 1) + + go func() { + err := b.db.Batch(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(bucketName) + for r := range batch { + txErr := b.batchPutOrDelete(bucket, shadowDelete, r) + if txErr != nil { + return txErr + } + } + return nil + }) + errs <- err + }() + + return batch, errs +} + +func (b *BBolt) batchPutOrDelete(bucket *bbolt.Bucket, shadowDelete bool, r record.Record) (err error) { + r.Lock() + defer r.Unlock() + + if !shadowDelete && r.Meta().IsDeleted() { + // Immediate delete. + err = bucket.Delete([]byte(r.DatabaseKey())) + } else { + // Put or shadow delete. + var data []byte + data, err = r.MarshalRecord(r) + if err == nil { + err = bucket.Put([]byte(r.DatabaseKey()), data) + } + } + + return err +} + +// Delete deletes a record from the database. +func (b *BBolt) Delete(key string) error { + err := b.db.Update(func(tx *bbolt.Tx) error { + txErr := tx.Bucket(bucketName).Delete([]byte(key)) + if txErr != nil { + return txErr + } + return nil + }) + if err != nil { + return err + } + return nil +} + +// Query returns a an iterator for the supplied query. +func (b *BBolt) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + _, err := q.Check() + if err != nil { + return nil, fmt.Errorf("invalid query: %w", err) + } + + queryIter := iterator.New() + + go b.queryExecutor(queryIter, q, local, internal) + return queryIter, nil +} + +func (b *BBolt) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) { + prefix := []byte(q.DatabaseKeyPrefix()) + err := b.db.View(func(tx *bbolt.Tx) error { + // Create a cursor for iteration. + c := tx.Bucket(bucketName).Cursor() + + // Iterate over items in sorted key order. 
This starts from the + // first key/value pair and updates the k/v variables to the + // next key/value on each iteration. + // + // The loop finishes at the end of the cursor when a nil key is returned. + for key, value := c.Seek(prefix); key != nil; key, value = c.Next() { + + // if we don't match the prefix anymore, exit + if !bytes.HasPrefix(key, prefix) { + return nil + } + + // wrap value + iterWrapper, err := record.NewRawWrapper(b.name, string(key), value) + if err != nil { + return err + } + + // check validity / access + if !iterWrapper.Meta().CheckValidity() { + continue + } + if !iterWrapper.Meta().CheckPermission(local, internal) { + continue + } + + // check if matches & send + if q.MatchesRecord(iterWrapper) { + // copy data + duplicate := make([]byte, len(value)) + copy(duplicate, value) + + newWrapper, err := record.NewRawWrapper(b.name, iterWrapper.DatabaseKey(), duplicate) + if err != nil { + return err + } + select { + case <-queryIter.Done: + return nil + case queryIter.Next <- newWrapper: + default: + select { + case <-queryIter.Done: + return nil + case queryIter.Next <- newWrapper: + case <-time.After(1 * time.Second): + return errors.New("query timeout") + } + } + } + } + return nil + }) + queryIter.Finish(err) +} + +// ReadOnly returns whether the database is read only. +func (b *BBolt) ReadOnly() bool { + return false +} + +// Injected returns whether the database is injected. +func (b *BBolt) Injected() bool { + return false +} + +// MaintainRecordStates maintains records states in the database. +func (b *BBolt) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error { //nolint:gocognit + now := time.Now().Unix() + purgeThreshold := purgeDeletedBefore.Unix() + + return b.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(bucketName) + // Create a cursor for iteration. + c := bucket.Cursor() + for key, value := c.First(); key != nil; key, value = c.Next() { + // check if context is cancelled + select { + case <-ctx.Done(): + return nil + default: + } + + // wrap value + wrapper, err := record.NewRawWrapper(b.name, string(key), value) + if err != nil { + return err + } + + // check if we need to do maintenance + meta := wrapper.Meta() + switch { + case meta.Deleted == 0 && meta.Expires > 0 && meta.Expires < now: + if shadowDelete { + // mark as deleted + meta.Deleted = meta.Expires + deleted, err := wrapper.MarshalRecord(wrapper) + if err != nil { + return err + } + err = bucket.Put(key, deleted) + if err != nil { + return err + } + + // Cursor repositioning is required after modifying data. + // While the documentation states that this is also required after a + // delete, this actually makes the cursor skip a record with the + // following c.Next() call of the loop. + // Docs/Issue: https://github.com/boltdb/bolt/issues/426#issuecomment-141982984 + c.Seek(key) + + continue + } + + // Immediately delete expired entries if shadowDelete is disabled. + fallthrough + case meta.Deleted > 0 && (!shadowDelete || meta.Deleted < purgeThreshold): + // delete from storage + err = c.Delete() + if err != nil { + return err + } + } + } + return nil + }) +} + +// Purge deletes all records that match the given query. It returns the number of successful deletes and an error. 
+func (b *BBolt) Purge(ctx context.Context, q *query.Query, local, internal, shadowDelete bool) (int, error) { //nolint:gocognit + prefix := []byte(q.DatabaseKeyPrefix()) + + var cnt int + var done bool + for !done { + err := b.db.Update(func(tx *bbolt.Tx) error { + // Create a cursor for iteration. + bucket := tx.Bucket(bucketName) + c := bucket.Cursor() + for key, value := c.Seek(prefix); key != nil; key, value = c.Next() { + // Check if context has been cancelled. + select { + case <-ctx.Done(): + done = true + return nil + default: + } + + // Check if we still match the key prefix, if not, exit. + if !bytes.HasPrefix(key, prefix) { + done = true + return nil + } + + // Wrap the value in a new wrapper to access the metadata. + wrapper, err := record.NewRawWrapper(b.name, string(key), value) + if err != nil { + return err + } + + // Check if we have permission for this record. + if !wrapper.Meta().CheckPermission(local, internal) { + continue + } + + // Check if record is already deleted. + if wrapper.Meta().IsDeleted() { + continue + } + + // Check if the query matches this record. + if !q.MatchesRecord(wrapper) { + continue + } + + // Delete record. + if shadowDelete { + // Shadow delete. + wrapper.Meta().Delete() + deleted, err := wrapper.MarshalRecord(wrapper) + if err != nil { + return err + } + err = bucket.Put(key, deleted) + if err != nil { + return err + } + + // Cursor repositioning is required after modifying data. + // While the documentation states that this is also required after a + // delete, this actually makes the cursor skip a record with the + // following c.Next() call of the loop. + // Docs/Issue: https://github.com/boltdb/bolt/issues/426#issuecomment-141982984 + c.Seek(key) + + } else { + // Immediate delete. + err = c.Delete() + if err != nil { + return err + } + } + + // Work in batches of 1000 changes in order to enable other operations in between. + cnt++ + if cnt%1000 == 0 { + return nil + } + } + done = true + return nil + }) + if err != nil { + return cnt, err + } + } + + return cnt, nil +} + +// Shutdown shuts down the database. +func (b *BBolt) Shutdown() error { + return b.db.Close() +} diff --git a/base/database/storage/bbolt/bbolt_test.go b/base/database/storage/bbolt/bbolt_test.go new file mode 100644 index 000000000..c2fb22b95 --- /dev/null +++ b/base/database/storage/bbolt/bbolt_test.go @@ -0,0 +1,206 @@ +package bbolt + +import ( + "context" + "os" + "reflect" + "sync" + "testing" + "time" + + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" +) + +var ( + // Compile time interface checks. 
+ _ storage.Interface = &BBolt{} + _ storage.Batcher = &BBolt{} + _ storage.Purger = &BBolt{} +) + +type TestRecord struct { //nolint:maligned + record.Base + sync.Mutex + S string + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + F32 float32 + F64 float64 + B bool +} + +func TestBBolt(t *testing.T) { + t.Parallel() + + testDir, err := os.MkdirTemp("", "testing-") + if err != nil { + t.Fatal(err) + } + defer func() { + _ = os.RemoveAll(testDir) // clean up + }() + + // start + db, err := NewBBolt("test", testDir) + if err != nil { + t.Fatal(err) + } + + a := &TestRecord{ + S: "banana", + I: 42, + I8: 42, + I16: 42, + I32: 42, + I64: 42, + UI: 42, + UI8: 42, + UI16: 42, + UI32: 42, + UI64: 42, + F32: 42.42, + F64: 42.42, + B: true, + } + a.SetMeta(&record.Meta{}) + a.Meta().Update() + a.SetKey("test:A") + + // put record + _, err = db.Put(a) + if err != nil { + t.Fatal(err) + } + + // get and compare + r1, err := db.Get("A") + if err != nil { + t.Fatal(err) + } + + a1 := &TestRecord{} + err = record.Unwrap(r1, a1) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(a, a1) { + t.Fatalf("mismatch, got %v", a1) + } + + // setup query test records + qA := &TestRecord{} + qA.SetKey("test:path/to/A") + qA.CreateMeta() + qB := &TestRecord{} + qB.SetKey("test:path/to/B") + qB.CreateMeta() + qC := &TestRecord{} + qC.SetKey("test:path/to/C") + qC.CreateMeta() + qZ := &TestRecord{} + qZ.SetKey("test:z") + qZ.CreateMeta() + // put + _, err = db.Put(qA) + if err == nil { + _, err = db.Put(qB) + } + if err == nil { + _, err = db.Put(qC) + } + if err == nil { + _, err = db.Put(qZ) + } + if err != nil { + t.Fatal(err) + } + + // test query + q := query.New("test:path/to/").MustBeValid() + it, err := db.Query(q, true, true) + if err != nil { + t.Fatal(err) + } + cnt := 0 + for range it.Next { + cnt++ + } + if it.Err() != nil { + t.Fatal(it.Err()) + } + if cnt != 3 { + t.Fatalf("unexpected query result count: %d", cnt) + } + + // delete + err = db.Delete("A") + if err != nil { + t.Fatal(err) + } + + // check if its gone + _, err = db.Get("A") + if err == nil { + t.Fatal("should fail") + } + + // maintenance + err = db.MaintainRecordStates(context.TODO(), time.Now(), true) + if err != nil { + t.Fatal(err) + } + + // maintenance + err = db.MaintainRecordStates(context.TODO(), time.Now(), false) + if err != nil { + t.Fatal(err) + } + + // purging + purger, ok := db.(storage.Purger) + if ok { + n, err := purger.Purge(context.TODO(), query.New("test:path/to/").MustBeValid(), true, true, false) + if err != nil { + t.Fatal(err) + } + if n != 3 { + t.Fatalf("unexpected purge delete count: %d", n) + } + } else { + t.Fatal("should implement Purger") + } + + // test query + q = query.New("test").MustBeValid() + it, err = db.Query(q, true, true) + if err != nil { + t.Fatal(err) + } + cnt = 0 + for range it.Next { + cnt++ + } + if it.Err() != nil { + t.Fatal(it.Err()) + } + if cnt != 1 { + t.Fatalf("unexpected query result count: %d", cnt) + } + + // shutdown + err = db.Shutdown() + if err != nil { + t.Fatal(err) + } +} diff --git a/base/database/storage/errors.go b/base/database/storage/errors.go new file mode 100644 index 000000000..ecc285304 --- /dev/null +++ b/base/database/storage/errors.go @@ -0,0 +1,8 @@ +package storage + +import "errors" + +// Errors for storages. 
+var ( + ErrNotFound = errors.New("storage entry not found") +) diff --git a/base/database/storage/fstree/fstree.go b/base/database/storage/fstree/fstree.go new file mode 100644 index 000000000..44cf384f2 --- /dev/null +++ b/base/database/storage/fstree/fstree.go @@ -0,0 +1,302 @@ +/* +Package fstree provides a dead simple file-based database storage backend. +It is primarily meant for easy testing or storing big files that can easily be accesses directly, without datastore. +*/ +package fstree + +import ( + "context" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" + "github.com/safing/portmaster/base/utils/renameio" +) + +const ( + defaultFileMode = os.FileMode(0o0644) + defaultDirMode = os.FileMode(0o0755) + onWindows = runtime.GOOS == "windows" +) + +// FSTree database storage. +type FSTree struct { + name string + basePath string +} + +func init() { + _ = storage.Register("fstree", NewFSTree) +} + +// NewFSTree returns a (new) FSTree database. +func NewFSTree(name, location string) (storage.Interface, error) { + basePath, err := filepath.Abs(location) + if err != nil { + return nil, fmt.Errorf("fstree: failed to validate path %s: %w", location, err) + } + + file, err := os.Stat(basePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(basePath, defaultDirMode) + if err != nil { + return nil, fmt.Errorf("fstree: failed to create directory %s: %w", basePath, err) + } + } else { + return nil, fmt.Errorf("fstree: failed to stat path %s: %w", basePath, err) + } + } else { + if !file.IsDir() { + return nil, fmt.Errorf("fstree: provided database path (%s) is a file", basePath) + } + } + + return &FSTree{ + name: name, + basePath: basePath, + }, nil +} + +func (fst *FSTree) buildFilePath(key string, checkKeyLength bool) (string, error) { + // check key length + if checkKeyLength && len(key) < 1 { + return "", fmt.Errorf("fstree: key too short: %s", key) + } + // build filepath + dstPath := filepath.Join(fst.basePath, key) // Join also calls Clean() + if !strings.HasPrefix(dstPath, fst.basePath) { + return "", fmt.Errorf("fstree: key integrity check failed, compiled path is %s", dstPath) + } + // return + return dstPath, nil +} + +// Get returns a database record. +func (fst *FSTree) Get(key string) (record.Record, error) { + dstPath, err := fst.buildFilePath(key, true) + if err != nil { + return nil, err + } + + data, err := os.ReadFile(dstPath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, storage.ErrNotFound + } + return nil, fmt.Errorf("fstree: failed to read file %s: %w", dstPath, err) + } + + r, err := record.NewRawWrapper(fst.name, key, data) + if err != nil { + return nil, err + } + return r, nil +} + +// GetMeta returns the metadata of a database record. +func (fst *FSTree) GetMeta(key string) (*record.Meta, error) { + // TODO: Replace with more performant variant. + + r, err := fst.Get(key) + if err != nil { + return nil, err + } + + return r.Meta(), nil +} + +// Put stores a record in the database. 
+func (fst *FSTree) Put(r record.Record) (record.Record, error) { + dstPath, err := fst.buildFilePath(r.DatabaseKey(), true) + if err != nil { + return nil, err + } + + data, err := r.MarshalRecord(r) + if err != nil { + return nil, err + } + + err = writeFile(dstPath, data, defaultFileMode) + if err != nil { + // create dir and try again + err = os.MkdirAll(filepath.Dir(dstPath), defaultDirMode) + if err != nil { + return nil, fmt.Errorf("fstree: failed to create directory %s: %w", filepath.Dir(dstPath), err) + } + err = writeFile(dstPath, data, defaultFileMode) + if err != nil { + return nil, fmt.Errorf("fstree: could not write file %s: %w", dstPath, err) + } + } + + return r, nil +} + +// Delete deletes a record from the database. +func (fst *FSTree) Delete(key string) error { + dstPath, err := fst.buildFilePath(key, true) + if err != nil { + return err + } + + // remove entry + err = os.Remove(dstPath) + if err != nil { + return fmt.Errorf("fstree: could not delete %s: %w", dstPath, err) + } + + return nil +} + +// Query returns a an iterator for the supplied query. +func (fst *FSTree) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + _, err := q.Check() + if err != nil { + return nil, fmt.Errorf("invalid query: %w", err) + } + + walkPrefix, err := fst.buildFilePath(q.DatabaseKeyPrefix(), false) + if err != nil { + return nil, err + } + fileInfo, err := os.Stat(walkPrefix) + var walkRoot string + switch { + case err == nil && fileInfo.IsDir(): + walkRoot = walkPrefix + case err == nil: + walkRoot = filepath.Dir(walkPrefix) + case errors.Is(err, fs.ErrNotExist): + walkRoot = filepath.Dir(walkPrefix) + default: // err != nil + return nil, fmt.Errorf("fstree: could not stat query root %s: %w", walkPrefix, err) + } + + queryIter := iterator.New() + + go fst.queryExecutor(walkRoot, queryIter, q, local, internal) + return queryIter, nil +} + +func (fst *FSTree) queryExecutor(walkRoot string, queryIter *iterator.Iterator, q *query.Query, local, internal bool) { + err := filepath.Walk(walkRoot, func(path string, info os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("fstree: error in walking fs: %w", err) + } + + if info.IsDir() { + // skip dir if not in scope + if !strings.HasPrefix(path, fst.basePath) { + return filepath.SkipDir + } + // continue + return nil + } + + // still in scope? + if !strings.HasPrefix(path, fst.basePath) { + return nil + } + + // read file + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return fmt.Errorf("fstree: failed to read file %s: %w", path, err) + } + + // parse + key, err := filepath.Rel(fst.basePath, path) + if err != nil { + return fmt.Errorf("fstree: failed to extract key from filepath %s: %w", path, err) + } + r, err := record.NewRawWrapper(fst.name, key, data) + if err != nil { + return fmt.Errorf("fstree: failed to load file %s: %w", path, err) + } + + if !r.Meta().CheckValidity() { + // record is not valid + return nil + } + + if !r.Meta().CheckPermission(local, internal) { + // no permission to access + return nil + } + + // check if matches, then send + if q.MatchesRecord(r) { + select { + case queryIter.Next <- r: + case <-queryIter.Done: + case <-time.After(1 * time.Second): + return errors.New("fstree: query buffer full, timeout") + } + } + + return nil + }) + + queryIter.Finish(err) +} + +// ReadOnly returns whether the database is read only. 
+func (fst *FSTree) ReadOnly() bool { + return false +} + +// Injected returns whether the database is injected. +func (fst *FSTree) Injected() bool { + return false +} + +// MaintainRecordStates maintains records states in the database. +func (fst *FSTree) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error { + // TODO: implement MaintainRecordStates + return nil +} + +// Shutdown shuts down the database. +func (fst *FSTree) Shutdown() error { + return nil +} + +// writeFile mirrors os.WriteFile, replacing an existing file with the same +// name atomically. This is not atomic on Windows, but still an improvement. +// TODO: Replace with github.com/google/renamio.WriteFile as soon as it is fixed on Windows. +// TODO: This has become a wont-fix. Explore other options. +// This function is forked from https://github.com/google/renameio/blob/a368f9987532a68a3d676566141654a81aa8100b/writefile.go. +func writeFile(filename string, data []byte, perm os.FileMode) error { + t, err := renameio.TempFile("", filename) + if err != nil { + return err + } + defer t.Cleanup() //nolint:errcheck + + // Set permissions before writing data, in case the data is sensitive. + if !onWindows { + if err := t.Chmod(perm); err != nil { + return err + } + } + + if _, err := t.Write(data); err != nil { + return err + } + + return t.CloseAtomicallyReplace() +} diff --git a/base/database/storage/fstree/fstree_test.go b/base/database/storage/fstree/fstree_test.go new file mode 100644 index 000000000..7d0c91c8d --- /dev/null +++ b/base/database/storage/fstree/fstree_test.go @@ -0,0 +1,6 @@ +package fstree + +import "github.com/safing/portmaster/base/database/storage" + +// Compile time interface checks. +var _ storage.Interface = &FSTree{} diff --git a/base/database/storage/hashmap/map.go b/base/database/storage/hashmap/map.go new file mode 100644 index 000000000..7114781f4 --- /dev/null +++ b/base/database/storage/hashmap/map.go @@ -0,0 +1,216 @@ +package hashmap + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" +) + +// HashMap storage. +type HashMap struct { + name string + db map[string]record.Record + dbLock sync.RWMutex +} + +func init() { + _ = storage.Register("hashmap", NewHashMap) +} + +// NewHashMap creates a hashmap database. +func NewHashMap(name, location string) (storage.Interface, error) { + return &HashMap{ + name: name, + db: make(map[string]record.Record), + }, nil +} + +// Get returns a database record. +func (hm *HashMap) Get(key string) (record.Record, error) { + hm.dbLock.RLock() + defer hm.dbLock.RUnlock() + + r, ok := hm.db[key] + if !ok { + return nil, storage.ErrNotFound + } + return r, nil +} + +// GetMeta returns the metadata of a database record. +func (hm *HashMap) GetMeta(key string) (*record.Meta, error) { + // TODO: Replace with more performant variant. + + r, err := hm.Get(key) + if err != nil { + return nil, err + } + + return r.Meta(), nil +} + +// Put stores a record in the database. +func (hm *HashMap) Put(r record.Record) (record.Record, error) { + hm.dbLock.Lock() + defer hm.dbLock.Unlock() + + hm.db[r.DatabaseKey()] = r + return r, nil +} + +// PutMany stores many records in the database. 
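// The batch protocol is the same for all Batcher implementations in this
// patch: the caller sends records on the returned channel, closes it, and
// then receives exactly one error (or nil) from the error channel.
// Illustrative sketch, with db (a Batcher) and records assumed to be set up
// already:
//
//	batch, errs := db.PutMany(true)
//	for _, r := range records {
//		batch <- r
//	}
//	close(batch)
//	if err := <-errs; err != nil {
//		return err
//	}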
+func (hm *HashMap) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error) { + hm.dbLock.Lock() + defer hm.dbLock.Unlock() + // we could lock for every record, but we want to have the same behaviour + // as the other storage backends, especially for testing. + + batch := make(chan record.Record, 100) + errs := make(chan error, 1) + + // start handler + go func() { + for r := range batch { + hm.batchPutOrDelete(shadowDelete, r) + } + errs <- nil + }() + + return batch, errs +} + +func (hm *HashMap) batchPutOrDelete(shadowDelete bool, r record.Record) { + r.Lock() + defer r.Unlock() + + hm.dbLock.Lock() + defer hm.dbLock.Unlock() + + if !shadowDelete && r.Meta().IsDeleted() { + delete(hm.db, r.DatabaseKey()) + } else { + hm.db[r.DatabaseKey()] = r + } +} + +// Delete deletes a record from the database. +func (hm *HashMap) Delete(key string) error { + hm.dbLock.Lock() + defer hm.dbLock.Unlock() + + delete(hm.db, key) + return nil +} + +// Query returns a an iterator for the supplied query. +func (hm *HashMap) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + _, err := q.Check() + if err != nil { + return nil, fmt.Errorf("invalid query: %w", err) + } + + queryIter := iterator.New() + + go hm.queryExecutor(queryIter, q, local, internal) + return queryIter, nil +} + +func (hm *HashMap) queryExecutor(queryIter *iterator.Iterator, q *query.Query, local, internal bool) { + hm.dbLock.RLock() + defer hm.dbLock.RUnlock() + + var err error + +mapLoop: + for key, record := range hm.db { + record.Lock() + if !q.MatchesKey(key) || + !q.MatchesRecord(record) || + !record.Meta().CheckValidity() || + !record.Meta().CheckPermission(local, internal) { + + record.Unlock() + continue + } + record.Unlock() + + select { + case <-queryIter.Done: + break mapLoop + case queryIter.Next <- record: + default: + select { + case <-queryIter.Done: + break mapLoop + case queryIter.Next <- record: + case <-time.After(1 * time.Second): + err = errors.New("query timeout") + break mapLoop + } + } + + } + + queryIter.Finish(err) +} + +// ReadOnly returns whether the database is read only. +func (hm *HashMap) ReadOnly() bool { + return false +} + +// Injected returns whether the database is injected. +func (hm *HashMap) Injected() bool { + return false +} + +// MaintainRecordStates maintains records states in the database. +func (hm *HashMap) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error { + hm.dbLock.Lock() + defer hm.dbLock.Unlock() + + now := time.Now().Unix() + purgeThreshold := purgeDeletedBefore.Unix() + + for key, record := range hm.db { + // check if context is cancelled + select { + case <-ctx.Done(): + return nil + default: + } + + meta := record.Meta() + switch { + case meta.Deleted == 0 && meta.Expires > 0 && meta.Expires < now: + if shadowDelete { + // mark as deleted + record.Lock() + meta.Deleted = meta.Expires + record.Unlock() + + continue + } + + // Immediately delete expired entries if shadowDelete is disabled. + fallthrough + case meta.Deleted > 0 && (!shadowDelete || meta.Deleted < purgeThreshold): + // delete from storage + delete(hm.db, key) + } + } + + return nil +} + +// Shutdown shuts down the database. 
+func (hm *HashMap) Shutdown() error { + return nil +} diff --git a/base/database/storage/hashmap/map_test.go b/base/database/storage/hashmap/map_test.go new file mode 100644 index 000000000..ff310c2d0 --- /dev/null +++ b/base/database/storage/hashmap/map_test.go @@ -0,0 +1,145 @@ +package hashmap + +import ( + "reflect" + "sync" + "testing" + + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" +) + +var ( + // Compile time interface checks. + _ storage.Interface = &HashMap{} + _ storage.Batcher = &HashMap{} +) + +type TestRecord struct { //nolint:maligned + record.Base + sync.Mutex + S string + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + F32 float32 + F64 float64 + B bool +} + +func TestHashMap(t *testing.T) { + t.Parallel() + + // start + db, err := NewHashMap("test", "") + if err != nil { + t.Fatal(err) + } + + a := &TestRecord{ + S: "banana", + I: 42, + I8: 42, + I16: 42, + I32: 42, + I64: 42, + UI: 42, + UI8: 42, + UI16: 42, + UI32: 42, + UI64: 42, + F32: 42.42, + F64: 42.42, + B: true, + } + a.SetMeta(&record.Meta{}) + a.Meta().Update() + a.SetKey("test:A") + + // put record + _, err = db.Put(a) + if err != nil { + t.Fatal(err) + } + + // get and compare + a1, err := db.Get("A") + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(a, a1) { + t.Fatalf("mismatch, got %v", a1) + } + + // setup query test records + qA := &TestRecord{} + qA.SetKey("test:path/to/A") + qA.CreateMeta() + qB := &TestRecord{} + qB.SetKey("test:path/to/B") + qB.CreateMeta() + qC := &TestRecord{} + qC.SetKey("test:path/to/C") + qC.CreateMeta() + qZ := &TestRecord{} + qZ.SetKey("test:z") + qZ.CreateMeta() + // put + _, err = db.Put(qA) + if err == nil { + _, err = db.Put(qB) + } + if err == nil { + _, err = db.Put(qC) + } + if err == nil { + _, err = db.Put(qZ) + } + if err != nil { + t.Fatal(err) + } + + // test query + q := query.New("test:path/to/").MustBeValid() + it, err := db.Query(q, true, true) + if err != nil { + t.Fatal(err) + } + cnt := 0 + for range it.Next { + cnt++ + } + if it.Err() != nil { + t.Fatal(it.Err()) + } + if cnt != 3 { + t.Fatalf("unexpected query result count: %d", cnt) + } + + // delete + err = db.Delete("A") + if err != nil { + t.Fatal(err) + } + + // check if its gone + _, err = db.Get("A") + if err == nil { + t.Fatal("should fail") + } + + // shutdown + err = db.Shutdown() + if err != nil { + t.Fatal(err) + } +} diff --git a/base/database/storage/injectbase.go b/base/database/storage/injectbase.go new file mode 100644 index 000000000..e7bee7328 --- /dev/null +++ b/base/database/storage/injectbase.go @@ -0,0 +1,60 @@ +package storage + +import ( + "context" + "errors" + "time" + + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" +) + +// ErrNotImplemented is returned when a function is not implemented by a storage. +var ErrNotImplemented = errors.New("not implemented") + +// InjectBase is a dummy base structure to reduce boilerplate code for injected storage interfaces. +type InjectBase struct{} + +// Compile time interface check. +var _ Interface = &InjectBase{} + +// Get returns a database record. +func (i *InjectBase) Get(key string) (record.Record, error) { + return nil, ErrNotImplemented +} + +// Put stores a record in the database. 
+func (i *InjectBase) Put(m record.Record) (record.Record, error) { + return nil, ErrNotImplemented +} + +// Delete deletes a record from the database. +func (i *InjectBase) Delete(key string) error { + return ErrNotImplemented +} + +// Query returns a an iterator for the supplied query. +func (i *InjectBase) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + return nil, ErrNotImplemented +} + +// ReadOnly returns whether the database is read only. +func (i *InjectBase) ReadOnly() bool { + return true +} + +// Injected returns whether the database is injected. +func (i *InjectBase) Injected() bool { + return true +} + +// MaintainRecordStates maintains records states in the database. +func (i *InjectBase) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error { + return nil +} + +// Shutdown shuts down the database. +func (i *InjectBase) Shutdown() error { + return nil +} diff --git a/base/database/storage/interface.go b/base/database/storage/interface.go new file mode 100644 index 000000000..c329a0a64 --- /dev/null +++ b/base/database/storage/interface.go @@ -0,0 +1,48 @@ +package storage + +import ( + "context" + "time" + + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" +) + +// Interface defines the database storage API. +type Interface interface { + // Primary Interface + Get(key string) (record.Record, error) + Put(m record.Record) (record.Record, error) + Delete(key string) error + Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) + + // Information and Control + ReadOnly() bool + Injected() bool + Shutdown() error + + // Mandatory Record Maintenance + MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error +} + +// MetaHandler defines the database storage API for backends that support optimized fetching of only the metadata. +type MetaHandler interface { + GetMeta(key string) (*record.Meta, error) +} + +// Maintainer defines the database storage API for backends that require regular maintenance. +type Maintainer interface { + Maintain(ctx context.Context) error + MaintainThorough(ctx context.Context) error +} + +// Batcher defines the database storage API for backends that support batch operations. +type Batcher interface { + PutMany(shadowDelete bool) (batch chan<- record.Record, errs <-chan error) +} + +// Purger defines the database storage API for backends that support the purge operation. +type Purger interface { + Purge(ctx context.Context, q *query.Query, local, internal, shadowDelete bool) (int, error) +} diff --git a/base/database/storage/sinkhole/sinkhole.go b/base/database/storage/sinkhole/sinkhole.go new file mode 100644 index 000000000..eb338b76f --- /dev/null +++ b/base/database/storage/sinkhole/sinkhole.go @@ -0,0 +1,111 @@ +package sinkhole + +import ( + "context" + "errors" + "time" + + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" +) + +// Sinkhole is a dummy storage. +type Sinkhole struct { + name string +} + +var ( + // Compile time interface checks. 
+ _ storage.Interface = &Sinkhole{} + _ storage.Maintainer = &Sinkhole{} + _ storage.Batcher = &Sinkhole{} +) + +func init() { + _ = storage.Register("sinkhole", NewSinkhole) +} + +// NewSinkhole creates a dummy database. +func NewSinkhole(name, location string) (storage.Interface, error) { + return &Sinkhole{ + name: name, + }, nil +} + +// Exists returns whether an entry with the given key exists. +func (s *Sinkhole) Exists(key string) (bool, error) { + return false, nil +} + +// Get returns a database record. +func (s *Sinkhole) Get(key string) (record.Record, error) { + return nil, storage.ErrNotFound +} + +// GetMeta returns the metadata of a database record. +func (s *Sinkhole) GetMeta(key string) (*record.Meta, error) { + return nil, storage.ErrNotFound +} + +// Put stores a record in the database. +func (s *Sinkhole) Put(r record.Record) (record.Record, error) { + return r, nil +} + +// PutMany stores many records in the database. +func (s *Sinkhole) PutMany(shadowDelete bool) (chan<- record.Record, <-chan error) { + batch := make(chan record.Record, 100) + errs := make(chan error, 1) + + // start handler + go func() { + for range batch { + // discard everything + } + errs <- nil + }() + + return batch, errs +} + +// Delete deletes a record from the database. +func (s *Sinkhole) Delete(key string) error { + return nil +} + +// Query returns a an iterator for the supplied query. +func (s *Sinkhole) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + return nil, errors.New("query not implemented by sinkhole") +} + +// ReadOnly returns whether the database is read only. +func (s *Sinkhole) ReadOnly() bool { + return false +} + +// Injected returns whether the database is injected. +func (s *Sinkhole) Injected() bool { + return false +} + +// Maintain runs a light maintenance operation on the database. +func (s *Sinkhole) Maintain(ctx context.Context) error { + return nil +} + +// MaintainThorough runs a thorough maintenance operation on the database. +func (s *Sinkhole) MaintainThorough(ctx context.Context) error { + return nil +} + +// MaintainRecordStates maintains records states in the database. +func (s *Sinkhole) MaintainRecordStates(ctx context.Context, purgeDeletedBefore time.Time, shadowDelete bool) error { + return nil +} + +// Shutdown shuts down the database. +func (s *Sinkhole) Shutdown() error { + return nil +} diff --git a/base/database/storage/storages.go b/base/database/storage/storages.go new file mode 100644 index 000000000..1fa744836 --- /dev/null +++ b/base/database/storage/storages.go @@ -0,0 +1,47 @@ +package storage + +import ( + "errors" + "fmt" + "sync" +) + +// A Factory creates a new database of it's type. +type Factory func(name, location string) (Interface, error) + +var ( + storages = make(map[string]Factory) + storagesLock sync.Mutex +) + +// Register registers a new storage type. +func Register(name string, factory Factory) error { + storagesLock.Lock() + defer storagesLock.Unlock() + + _, ok := storages[name] + if ok { + return errors.New("factory for this type already exists") + } + + storages[name] = factory + return nil +} + +// CreateDatabase starts a new database with the given name and storageType at location. +func CreateDatabase(name, storageType, location string) (Interface, error) { + return nil, nil +} + +// StartDatabase starts a new database with the given name and storageType at location. 
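Editor's note: backends register a Factory under a storage type name, as the sinkhole does in its init function above, and StartDatabase, defined just below, looks that factory up and invokes it. A minimal sketch of wiring a custom type name into this registry; the "memlog" type name and the reuse of the hashmap constructor are illustrative assumptions, not something this patch sets up.

```go
package example

import (
	"github.com/safing/portmaster/base/database/storage"
	"github.com/safing/portmaster/base/database/storage/hashmap"
)

func init() {
	// Register a custom storage type. The factory is only invoked when a
	// database of this type is started. For illustration we simply reuse
	// the hashmap backend under a different type name.
	_ = storage.Register("memlog", func(name, location string) (storage.Interface, error) {
		return hashmap.NewHashMap(name, location)
	})
}

// openExample resolves the registered factory by type name and creates the
// backend instance. Name and location values are placeholders.
func openExample() (storage.Interface, error) {
	return storage.StartDatabase("example-db", "memlog", "")
}
```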
+func StartDatabase(name, storageType, location string) (Interface, error) { + storagesLock.Lock() + defer storagesLock.Unlock() + + factory, ok := storages[storageType] + if !ok { + return nil, fmt.Errorf("storage type %s not registered", storageType) + } + + return factory(name, location) +} diff --git a/base/database/subscription.go b/base/database/subscription.go new file mode 100644 index 000000000..b995ecd6e --- /dev/null +++ b/base/database/subscription.go @@ -0,0 +1,35 @@ +package database + +import ( + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" +) + +// Subscription is a database subscription for updates. +type Subscription struct { + q *query.Query + local bool + internal bool + + Feed chan record.Record +} + +// Cancel cancels the subscription. +func (s *Subscription) Cancel() error { + c, err := getController(s.q.DatabaseName()) + if err != nil { + return err + } + + c.subscriptionLock.Lock() + defer c.subscriptionLock.Unlock() + + for key, sub := range c.subscriptions { + if sub.q == s.q { + c.subscriptions = append(c.subscriptions[:key], c.subscriptions[key+1:]...) + close(s.Feed) // this close is guarded by the controllers subscriptionLock. + return nil + } + } + return nil +} diff --git a/base/dataroot/root.go b/base/dataroot/root.go new file mode 100644 index 000000000..296b342f3 --- /dev/null +++ b/base/dataroot/root.go @@ -0,0 +1,25 @@ +package dataroot + +import ( + "errors" + "os" + + "github.com/safing/portmaster/base/utils" +) + +var root *utils.DirStructure + +// Initialize initializes the data root directory. +func Initialize(rootDir string, perm os.FileMode) error { + if root != nil { + return errors.New("already initialized") + } + + root = utils.NewDirStructure(rootDir, perm) + return root.Ensure() +} + +// Root returns the data root directory. +func Root() *utils.DirStructure { + return root +} diff --git a/base/formats/dsd/compression.go b/base/formats/dsd/compression.go new file mode 100644 index 000000000..d1baf2822 --- /dev/null +++ b/base/formats/dsd/compression.go @@ -0,0 +1,103 @@ +package dsd + +import ( + "bytes" + "compress/gzip" + "errors" + + "github.com/safing/portmaster/base/formats/varint" +) + +// DumpAndCompress stores the interface as a dsd formatted data structure and compresses the resulting data. +func DumpAndCompress(t interface{}, format uint8, compression uint8) ([]byte, error) { + // Check if compression format is valid. + compression, ok := ValidateCompressionFormat(compression) + if !ok { + return nil, ErrIncompatibleFormat + } + + // Dump the given data with the given format. + data, err := Dump(t, format) + if err != nil { + return nil, err + } + + // prepare writer + packetFormat := varint.Pack8(compression) + buf := bytes.NewBuffer(nil) + buf.Write(packetFormat) + + // compress + switch compression { + case GZIP: + // create gzip writer + gzipWriter, err := gzip.NewWriterLevel(buf, gzip.BestCompression) + if err != nil { + return nil, err + } + + // write data + n, err := gzipWriter.Write(data) + if err != nil { + return nil, err + } + if n != len(data) { + return nil, errors.New("failed to fully write to gzip compressor") + } + + // flush and write gzip footer + err = gzipWriter.Close() + if err != nil { + return nil, err + } + default: + return nil, ErrIncompatibleFormat + } + + return buf.Bytes(), nil +} + +// DecompressAndLoad decompresses the data using the specified compression format and then loads the resulting data blob into the interface. 
+func DecompressAndLoad(data []byte, compression uint8, t interface{}) (format uint8, err error) { + // Check if compression format is valid. + _, ok := ValidateCompressionFormat(compression) + if !ok { + return 0, ErrIncompatibleFormat + } + + // prepare reader + buf := bytes.NewBuffer(nil) + + // decompress + switch compression { + case GZIP: + // create gzip reader + gzipReader, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return 0, err + } + + // read uncompressed data + _, err = buf.ReadFrom(gzipReader) + if err != nil { + return 0, err + } + + // flush and verify gzip footer + err = gzipReader.Close() + if err != nil { + return 0, err + } + default: + return 0, ErrIncompatibleFormat + } + + // assign decompressed data + data = buf.Bytes() + + format, read, err := loadFormat(data) + if err != nil { + return 0, err + } + return format, LoadAsFormat(data[read:], format, t) +} diff --git a/base/formats/dsd/dsd.go b/base/formats/dsd/dsd.go new file mode 100644 index 000000000..76b8c4446 --- /dev/null +++ b/base/formats/dsd/dsd.go @@ -0,0 +1,160 @@ +package dsd + +// dynamic structured data +// check here for some benchmarks: https://github.com/alecthomas/go_serialization_benchmarks + +import ( + "encoding/json" + "errors" + "fmt" + "io" + + "github.com/fxamacker/cbor/v2" + "github.com/ghodss/yaml" + "github.com/vmihailenco/msgpack/v5" + + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/utils" +) + +// Load loads an dsd structured data blob into the given interface. +func Load(data []byte, t interface{}) (format uint8, err error) { + format, read, err := loadFormat(data) + if err != nil { + return 0, err + } + + _, ok := ValidateSerializationFormat(format) + if ok { + return format, LoadAsFormat(data[read:], format, t) + } + return DecompressAndLoad(data[read:], format, t) +} + +// LoadAsFormat loads a data blob into the interface using the specified format. +func LoadAsFormat(data []byte, format uint8, t interface{}) (err error) { + switch format { + case RAW: + return ErrIsRaw + case JSON: + err = json.Unmarshal(data, t) + if err != nil { + return fmt.Errorf("dsd: failed to unpack json: %w, data: %s", err, utils.SafeFirst16Bytes(data)) + } + return nil + case YAML: + err = yaml.Unmarshal(data, t) + if err != nil { + return fmt.Errorf("dsd: failed to unpack yaml: %w, data: %s", err, utils.SafeFirst16Bytes(data)) + } + return nil + case CBOR: + err = cbor.Unmarshal(data, t) + if err != nil { + return fmt.Errorf("dsd: failed to unpack cbor: %w, data: %s", err, utils.SafeFirst16Bytes(data)) + } + return nil + case MsgPack: + err = msgpack.Unmarshal(data, t) + if err != nil { + return fmt.Errorf("dsd: failed to unpack msgpack: %w, data: %s", err, utils.SafeFirst16Bytes(data)) + } + return nil + case GenCode: + genCodeStruct, ok := t.(GenCodeCompatible) + if !ok { + return errors.New("dsd: gencode is not supported by the given data structure") + } + _, err = genCodeStruct.GenCodeUnmarshal(data) + if err != nil { + return fmt.Errorf("dsd: failed to unpack gencode: %w, data: %s", err, utils.SafeFirst16Bytes(data)) + } + return nil + default: + return ErrIncompatibleFormat + } +} + +func loadFormat(data []byte) (format uint8, read int, err error) { + format, read, err = varint.Unpack8(data) + if err != nil { + return 0, 0, err + } + if len(data) <= read { + return 0, 0, io.ErrUnexpectedEOF + } + + return format, read, nil +} + +// Dump stores the interface as a dsd formatted data structure. 
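Editor's note: Dump prefixes the serialized payload with a varint format identifier, and Load reads that identifier back to pick the right decoder, handing off to DecompressAndLoad when the identifier names a compression format instead. A small round-trip sketch under that reading of the code; the payload type is made up for illustration.

```go
package example

import (
	"fmt"

	"github.com/safing/portmaster/base/formats/dsd"
)

// exampleConfig is a hypothetical payload type used only for illustration.
type exampleConfig struct {
	Name    string
	Retries int
}

func roundTrip() error {
	in := &exampleConfig{Name: "test", Retries: 3}

	// Serialize as JSON; the returned blob starts with the JSON format
	// identifier so Load can detect the format on its own.
	plain, err := dsd.Dump(in, dsd.JSON)
	if err != nil {
		return err
	}

	// Alternatively, serialize and gzip-compress; the blob then starts with
	// the GZIP identifier followed by the compressed JSON payload.
	packed, err := dsd.DumpAndCompress(in, dsd.JSON, dsd.GZIP)
	if err != nil {
		return err
	}

	// Load detects both variants from the leading format identifier.
	for _, blob := range [][]byte{plain, packed} {
		out := &exampleConfig{}
		format, err := dsd.Load(blob, out)
		if err != nil {
			return err
		}
		fmt.Printf("decoded format %d: %+v\n", format, out)
	}
	return nil
}
```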
+func Dump(t interface{}, format uint8) ([]byte, error) { + return DumpIndent(t, format, "") +} + +// DumpIndent stores the interface as a dsd formatted data structure with indentation, if available. +func DumpIndent(t interface{}, format uint8, indent string) ([]byte, error) { + data, err := dumpWithoutIdentifier(t, format, indent) + if err != nil { + return nil, err + } + + // TODO: Find a better way to do this. + return append(varint.Pack8(format), data...), nil +} + +func dumpWithoutIdentifier(t interface{}, format uint8, indent string) ([]byte, error) { + format, ok := ValidateSerializationFormat(format) + if !ok { + return nil, ErrIncompatibleFormat + } + + var data []byte + var err error + switch format { + case RAW: + var ok bool + data, ok = t.([]byte) + if !ok { + return nil, ErrIncompatibleFormat + } + case JSON: + // TODO: use SetEscapeHTML(false) + if indent != "" { + data, err = json.MarshalIndent(t, "", indent) + } else { + data, err = json.Marshal(t) + } + if err != nil { + return nil, err + } + case YAML: + data, err = yaml.Marshal(t) + if err != nil { + return nil, err + } + case CBOR: + data, err = cbor.Marshal(t) + if err != nil { + return nil, err + } + case MsgPack: + data, err = msgpack.Marshal(t) + if err != nil { + return nil, err + } + case GenCode: + genCodeStruct, ok := t.(GenCodeCompatible) + if !ok { + return nil, errors.New("dsd: gencode is not supported by the given data structure") + } + data, err = genCodeStruct.GenCodeMarshal(nil) + if err != nil { + return nil, fmt.Errorf("dsd: failed to pack gencode struct: %w", err) + } + default: + return nil, ErrIncompatibleFormat + } + + return data, nil +} diff --git a/base/formats/dsd/dsd_test.go b/base/formats/dsd/dsd_test.go new file mode 100644 index 000000000..479f72711 --- /dev/null +++ b/base/formats/dsd/dsd_test.go @@ -0,0 +1,327 @@ +//nolint:maligned,gocyclo,gocognit +package dsd + +import ( + "math/big" + "reflect" + "testing" +) + +// SimpleTestStruct is used for testing. 
+type SimpleTestStruct struct { + S string + B byte +} + +type ComplexTestStruct struct { + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + BI *big.Int + S string + Sp *string + Sa []string + Sap *[]string + B byte + Bp *byte + Ba []byte + Bap *[]byte + M map[string]string + Mp *map[string]string +} + +type GenCodeTestStruct struct { + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + S string + Sp *string + Sa []string + Sap *[]string + B byte + Bp *byte + Ba []byte + Bap *[]byte +} + +var ( + simpleSubject = &SimpleTestStruct{ + "a", + 0x01, + } + + bString = "b" + bBytes byte = 0x02 + + complexSubject = &ComplexTestStruct{ + -1, + -2, + -3, + -4, + -5, + 1, + 2, + 3, + 4, + 5, + big.NewInt(6), + "a", + &bString, + []string{"c", "d", "e"}, + &[]string{"f", "g", "h"}, + 0x01, + &bBytes, + []byte{0x03, 0x04, 0x05}, + &[]byte{0x05, 0x06, 0x07}, + map[string]string{ + "a": "b", + "c": "d", + "e": "f", + }, + &map[string]string{ + "g": "h", + "i": "j", + "k": "l", + }, + } + + genCodeSubject = &GenCodeTestStruct{ + -2, + -3, + -4, + -5, + 2, + 3, + 4, + 5, + "a", + &bString, + []string{"c", "d", "e"}, + &[]string{"f", "g", "h"}, + 0x01, + &bBytes, + []byte{0x03, 0x04, 0x05}, + &[]byte{0x05, 0x06, 0x07}, + } +) + +func TestConversion(t *testing.T) { //nolint:maintidx + t.Parallel() + + compressionFormats := []uint8{AUTO, GZIP} + formats := []uint8{JSON, CBOR, MsgPack} + + for _, compression := range compressionFormats { + for _, format := range formats { + + // simple + var b []byte + var err error + if compression != AUTO { + b, err = DumpAndCompress(simpleSubject, format, compression) + } else { + b, err = Dump(simpleSubject, format) + } + if err != nil { + t.Fatalf("Dump error (simple struct): %s", err) + } + + si := &SimpleTestStruct{} + _, err = Load(b, si) + if err != nil { + t.Fatalf("Load error (simple struct): %s", err) + } + + if !reflect.DeepEqual(simpleSubject, si) { + t.Errorf("Load (simple struct): subject does not match loaded object") + t.Errorf("Encoded: %v", string(b)) + t.Errorf("Compared: %v == %v", simpleSubject, si) + } + + // complex + if compression != AUTO { + b, err = DumpAndCompress(complexSubject, format, compression) + } else { + b, err = Dump(complexSubject, format) + } + if err != nil { + t.Fatalf("Dump error (complex struct): %s", err) + } + + co := &ComplexTestStruct{} + _, err = Load(b, co) + if err != nil { + t.Fatalf("Load error (complex struct): %s", err) + } + + if complexSubject.I != co.I { + t.Errorf("Load (complex struct): struct.I is not equal (%v != %v)", complexSubject.I, co.I) + } + if complexSubject.I8 != co.I8 { + t.Errorf("Load (complex struct): struct.I8 is not equal (%v != %v)", complexSubject.I8, co.I8) + } + if complexSubject.I16 != co.I16 { + t.Errorf("Load (complex struct): struct.I16 is not equal (%v != %v)", complexSubject.I16, co.I16) + } + if complexSubject.I32 != co.I32 { + t.Errorf("Load (complex struct): struct.I32 is not equal (%v != %v)", complexSubject.I32, co.I32) + } + if complexSubject.I64 != co.I64 { + t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", complexSubject.I64, co.I64) + } + if complexSubject.UI != co.UI { + t.Errorf("Load (complex struct): struct.UI is not equal (%v != %v)", complexSubject.UI, co.UI) + } + if complexSubject.UI8 != co.UI8 { + t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", complexSubject.UI8, co.UI8) + } + if complexSubject.UI16 != co.UI16 { 
+ t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", complexSubject.UI16, co.UI16) + } + if complexSubject.UI32 != co.UI32 { + t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", complexSubject.UI32, co.UI32) + } + if complexSubject.UI64 != co.UI64 { + t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", complexSubject.UI64, co.UI64) + } + if complexSubject.BI.Cmp(co.BI) != 0 { + t.Errorf("Load (complex struct): struct.BI is not equal (%v != %v)", complexSubject.BI, co.BI) + } + if complexSubject.S != co.S { + t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", complexSubject.S, co.S) + } + if !reflect.DeepEqual(complexSubject.Sp, co.Sp) { + t.Errorf("Load (complex struct): struct.Sp is not equal (%v != %v)", complexSubject.Sp, co.Sp) + } + if !reflect.DeepEqual(complexSubject.Sa, co.Sa) { + t.Errorf("Load (complex struct): struct.Sa is not equal (%v != %v)", complexSubject.Sa, co.Sa) + } + if !reflect.DeepEqual(complexSubject.Sap, co.Sap) { + t.Errorf("Load (complex struct): struct.Sap is not equal (%v != %v)", complexSubject.Sap, co.Sap) + } + if complexSubject.B != co.B { + t.Errorf("Load (complex struct): struct.B is not equal (%v != %v)", complexSubject.B, co.B) + } + if !reflect.DeepEqual(complexSubject.Bp, co.Bp) { + t.Errorf("Load (complex struct): struct.Bp is not equal (%v != %v)", complexSubject.Bp, co.Bp) + } + if !reflect.DeepEqual(complexSubject.Ba, co.Ba) { + t.Errorf("Load (complex struct): struct.Ba is not equal (%v != %v)", complexSubject.Ba, co.Ba) + } + if !reflect.DeepEqual(complexSubject.Bap, co.Bap) { + t.Errorf("Load (complex struct): struct.Bap is not equal (%v != %v)", complexSubject.Bap, co.Bap) + } + if !reflect.DeepEqual(complexSubject.M, co.M) { + t.Errorf("Load (complex struct): struct.M is not equal (%v != %v)", complexSubject.M, co.M) + } + if !reflect.DeepEqual(complexSubject.Mp, co.Mp) { + t.Errorf("Load (complex struct): struct.Mp is not equal (%v != %v)", complexSubject.Mp, co.Mp) + } + + } + + // test all formats + simplifiedFormatTesting := []uint8{JSON, CBOR, MsgPack, GenCode} + + for _, format := range simplifiedFormatTesting { + + // simple + var b []byte + var err error + if compression != AUTO { + b, err = DumpAndCompress(simpleSubject, format, compression) + } else { + b, err = Dump(simpleSubject, format) + } + if err != nil { + t.Fatalf("Dump error (simple struct): %s", err) + } + + si := &SimpleTestStruct{} + _, err = Load(b, si) + if err != nil { + t.Fatalf("Load error (simple struct): %s", err) + } + + if !reflect.DeepEqual(simpleSubject, si) { + t.Errorf("Load (simple struct): subject does not match loaded object") + t.Errorf("Encoded: %v", string(b)) + t.Errorf("Compared: %v == %v", simpleSubject, si) + } + + // complex + b, err = DumpAndCompress(genCodeSubject, format, compression) + if err != nil { + t.Fatalf("Dump error (complex struct): %s", err) + } + + co := &GenCodeTestStruct{} + _, err = Load(b, co) + if err != nil { + t.Fatalf("Load error (complex struct): %s", err) + } + + if genCodeSubject.I8 != co.I8 { + t.Errorf("Load (complex struct): struct.I8 is not equal (%v != %v)", genCodeSubject.I8, co.I8) + } + if genCodeSubject.I16 != co.I16 { + t.Errorf("Load (complex struct): struct.I16 is not equal (%v != %v)", genCodeSubject.I16, co.I16) + } + if genCodeSubject.I32 != co.I32 { + t.Errorf("Load (complex struct): struct.I32 is not equal (%v != %v)", genCodeSubject.I32, co.I32) + } + if genCodeSubject.I64 != co.I64 { + t.Errorf("Load (complex struct): struct.I64 
is not equal (%v != %v)", genCodeSubject.I64, co.I64) + } + if genCodeSubject.UI8 != co.UI8 { + t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", genCodeSubject.UI8, co.UI8) + } + if genCodeSubject.UI16 != co.UI16 { + t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", genCodeSubject.UI16, co.UI16) + } + if genCodeSubject.UI32 != co.UI32 { + t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", genCodeSubject.UI32, co.UI32) + } + if genCodeSubject.UI64 != co.UI64 { + t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", genCodeSubject.UI64, co.UI64) + } + if genCodeSubject.S != co.S { + t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", genCodeSubject.S, co.S) + } + if !reflect.DeepEqual(genCodeSubject.Sp, co.Sp) { + t.Errorf("Load (complex struct): struct.Sp is not equal (%v != %v)", genCodeSubject.Sp, co.Sp) + } + if !reflect.DeepEqual(genCodeSubject.Sa, co.Sa) { + t.Errorf("Load (complex struct): struct.Sa is not equal (%v != %v)", genCodeSubject.Sa, co.Sa) + } + if !reflect.DeepEqual(genCodeSubject.Sap, co.Sap) { + t.Errorf("Load (complex struct): struct.Sap is not equal (%v != %v)", genCodeSubject.Sap, co.Sap) + } + if genCodeSubject.B != co.B { + t.Errorf("Load (complex struct): struct.B is not equal (%v != %v)", genCodeSubject.B, co.B) + } + if !reflect.DeepEqual(genCodeSubject.Bp, co.Bp) { + t.Errorf("Load (complex struct): struct.Bp is not equal (%v != %v)", genCodeSubject.Bp, co.Bp) + } + if !reflect.DeepEqual(genCodeSubject.Ba, co.Ba) { + t.Errorf("Load (complex struct): struct.Ba is not equal (%v != %v)", genCodeSubject.Ba, co.Ba) + } + if !reflect.DeepEqual(genCodeSubject.Bap, co.Bap) { + t.Errorf("Load (complex struct): struct.Bap is not equal (%v != %v)", genCodeSubject.Bap, co.Bap) + } + } + + } +} diff --git a/base/formats/dsd/format.go b/base/formats/dsd/format.go new file mode 100644 index 000000000..c97950464 --- /dev/null +++ b/base/formats/dsd/format.go @@ -0,0 +1,73 @@ +package dsd + +import "errors" + +// Errors. +var ( + ErrIncompatibleFormat = errors.New("dsd: format is incompatible with operation") + ErrIsRaw = errors.New("dsd: given data is in raw format") + ErrUnknownFormat = errors.New("dsd: format is unknown") +) + +// Format types. +const ( + AUTO = 0 + + // Serialization types. + RAW = 1 + CBOR = 67 // C + GenCode = 71 // G + JSON = 74 // J + MsgPack = 77 // M + YAML = 89 // Y + + // Compression types. + GZIP = 90 // Z + + // Special types. + LIST = 76 // L +) + +// Default Formats. +var ( + DefaultSerializationFormat uint8 = JSON + DefaultCompressionFormat uint8 = GZIP +) + +// ValidateSerializationFormat validates if the format is for serialization, +// and returns the validated format as well as the result of the validation. +// If called on the AUTO format, it returns the default serialization format. +func ValidateSerializationFormat(format uint8) (validatedFormat uint8, ok bool) { + switch format { + case AUTO: + return DefaultSerializationFormat, true + case RAW: + return format, true + case CBOR: + return format, true + case GenCode: + return format, true + case JSON: + return format, true + case YAML: + return format, true + case MsgPack: + return format, true + default: + return 0, false + } +} + +// ValidateCompressionFormat validates if the format is for compression, +// and returns the validated format as well as the result of the validation. +// If called on the AUTO format, it returns the default compression format. 
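Editor's note: the serialization identifiers above double as printable ASCII codes (74 is 'J' for JSON, 67 is 'C' for CBOR, 77 is 'M' for MsgPack, 89 is 'Y' for YAML, 90 is 'Z' for GZIP), and since all of them are below 128 the varint prefix is a single byte. A small sketch that uses this to label a stored blob by peeking at its first byte; the helper itself is hypothetical and not part of the package.

```go
package example

import (
	"github.com/safing/portmaster/base/formats/dsd"
	"github.com/safing/portmaster/base/formats/varint"
)

// describeBlob is a hypothetical helper that reports which dsd format a
// stored blob was written in, based on its leading format identifier.
func describeBlob(blob []byte) string {
	format, _, err := varint.Unpack8(blob)
	if err != nil {
		return "invalid"
	}

	switch format {
	case dsd.JSON:
		return "json"
	case dsd.CBOR:
		return "cbor"
	case dsd.MsgPack:
		return "msgpack"
	case dsd.YAML:
		return "yaml"
	case dsd.GenCode:
		return "gencode"
	case dsd.RAW:
		return "raw"
	case dsd.GZIP:
		return "gzip-compressed"
	default:
		return "unknown"
	}
}
```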
+func ValidateCompressionFormat(format uint8) (validatedFormat uint8, ok bool) { + switch format { + case AUTO: + return DefaultCompressionFormat, true + case GZIP: + return format, true + default: + return 0, false + } +} diff --git a/base/formats/dsd/gencode_test.go b/base/formats/dsd/gencode_test.go new file mode 100644 index 000000000..2fbf18a00 --- /dev/null +++ b/base/formats/dsd/gencode_test.go @@ -0,0 +1,824 @@ +//nolint:nakedret,unconvert,gocognit,wastedassign,gofumpt +package dsd + +func (d *SimpleTestStruct) Size() (s uint64) { + + { + l := uint64(len(d.S)) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + s++ + return +} + +func (d *SimpleTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) { + size := d.Size() + { + if uint64(cap(buf)) >= size { + buf = buf[:size] + } else { + buf = make([]byte, size) + } + } + i := uint64(0) + + { + l := uint64(len(d.S)) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+0] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+0] = byte(t) + i++ + + } + copy(buf[i+0:], d.S) + i += l + } + { + buf[i+0] = d.B + } + return buf[:i+1], nil +} + +func (d *SimpleTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) { + i := uint64(0) + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+0] & 0x7F) + for buf[i+0]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+0]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + d.S = string(buf[i+0 : i+0+l]) + i += l + } + { + d.B = buf[i+0] + } + return i + 1, nil +} + +func (d *GenCodeTestStruct) Size() (s uint64) { + + { + l := uint64(len(d.S)) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + { + if d.Sp != nil { + + { + l := uint64(len((*d.Sp))) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + s += 0 + } + } + { + l := uint64(len(d.Sa)) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + + for k0 := range d.Sa { + + { + l := uint64(len(d.Sa[k0])) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + + } + + } + { + if d.Sap != nil { + + { + l := uint64(len((*d.Sap))) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + + for k0 := range *d.Sap { + + { + l := uint64(len((*d.Sap)[k0])) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + + } + + } + s += 0 + } + } + { + if d.Bp != nil { + + s++ + } + } + { + l := uint64(len(d.Ba)) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + { + if d.Bap != nil { + + { + l := uint64(len((*d.Bap))) + + { + + t := l + for t >= 0x80 { + t >>= 7 + s++ + } + s++ + + } + s += l + } + s += 0 + } + } + s += 35 + return +} + +func (d *GenCodeTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) { //nolint:maintidx + size := d.Size() + { + if uint64(cap(buf)) >= size { + buf = buf[:size] + } else { + buf = make([]byte, size) + } + } + i := uint64(0) + + { + + buf[0+0] = byte(d.I8 >> 0) + + } + { + + buf[0+1] = byte(d.I16 >> 0) + + buf[1+1] = byte(d.I16 >> 8) + + } + { + + buf[0+3] = byte(d.I32 >> 0) + + buf[1+3] = byte(d.I32 >> 8) + + buf[2+3] = byte(d.I32 >> 16) + + buf[3+3] = byte(d.I32 >> 24) + + } + { + + buf[0+7] = byte(d.I64 >> 0) + + buf[1+7] = byte(d.I64 >> 8) + + buf[2+7] = byte(d.I64 >> 16) + + buf[3+7] = byte(d.I64 >> 24) + + buf[4+7] = byte(d.I64 >> 32) + + buf[5+7] = byte(d.I64 >> 40) + + buf[6+7] = byte(d.I64 >> 48) + + buf[7+7] = byte(d.I64 >> 56) + + } + { + + buf[0+15] = byte(d.UI8 >> 0) + + } + { + + buf[0+16] = byte(d.UI16 >> 0) + + buf[1+16] = 
byte(d.UI16 >> 8) + + } + { + + buf[0+18] = byte(d.UI32 >> 0) + + buf[1+18] = byte(d.UI32 >> 8) + + buf[2+18] = byte(d.UI32 >> 16) + + buf[3+18] = byte(d.UI32 >> 24) + + } + { + + buf[0+22] = byte(d.UI64 >> 0) + + buf[1+22] = byte(d.UI64 >> 8) + + buf[2+22] = byte(d.UI64 >> 16) + + buf[3+22] = byte(d.UI64 >> 24) + + buf[4+22] = byte(d.UI64 >> 32) + + buf[5+22] = byte(d.UI64 >> 40) + + buf[6+22] = byte(d.UI64 >> 48) + + buf[7+22] = byte(d.UI64 >> 56) + + } + { + l := uint64(len(d.S)) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+30] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+30] = byte(t) + i++ + + } + copy(buf[i+30:], d.S) + i += l + } + { + if d.Sp == nil { + buf[i+30] = 0 + } else { + buf[i+30] = 1 + + { + l := uint64(len((*d.Sp))) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+31] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+31] = byte(t) + i++ + + } + copy(buf[i+31:], (*d.Sp)) + i += l + } + i += 0 + } + } + { + l := uint64(len(d.Sa)) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+31] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+31] = byte(t) + i++ + + } + for k0 := range d.Sa { + + { + l := uint64(len(d.Sa[k0])) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+31] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+31] = byte(t) + i++ + + } + copy(buf[i+31:], d.Sa[k0]) + i += l + } + + } + } + { + if d.Sap == nil { + buf[i+31] = 0 + } else { + buf[i+31] = 1 + + { + l := uint64(len((*d.Sap))) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+32] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+32] = byte(t) + i++ + + } + for k0 := range *d.Sap { + + { + l := uint64(len((*d.Sap)[k0])) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+32] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+32] = byte(t) + i++ + + } + copy(buf[i+32:], (*d.Sap)[k0]) + i += l + } + + } + } + i += 0 + } + } + { + buf[i+32] = d.B + } + { + if d.Bp == nil { + buf[i+33] = 0 + } else { + buf[i+33] = 1 + + { + buf[i+34] = (*d.Bp) + } + i++ + } + } + { + l := uint64(len(d.Ba)) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+34] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+34] = byte(t) + i++ + + } + copy(buf[i+34:], d.Ba) + i += l + } + { + if d.Bap == nil { + buf[i+34] = 0 + } else { + buf[i+34] = 1 + + { + l := uint64(len((*d.Bap))) + + { + + t := uint64(l) + + for t >= 0x80 { + buf[i+35] = byte(t) | 0x80 + t >>= 7 + i++ + } + buf[i+35] = byte(t) + i++ + + } + copy(buf[i+35:], (*d.Bap)) + i += l + } + i += 0 + } + } + return buf[:i+35], nil +} + +func (d *GenCodeTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) { //nolint:maintidx + i := uint64(0) + + { + + d.I8 = 0 | (int8(buf[i+0+0]) << 0) + + } + { + + d.I16 = 0 | (int16(buf[i+0+1]) << 0) | (int16(buf[i+1+1]) << 8) + + } + { + + d.I32 = 0 | (int32(buf[i+0+3]) << 0) | (int32(buf[i+1+3]) << 8) | (int32(buf[i+2+3]) << 16) | (int32(buf[i+3+3]) << 24) + + } + { + + d.I64 = 0 | (int64(buf[i+0+7]) << 0) | (int64(buf[i+1+7]) << 8) | (int64(buf[i+2+7]) << 16) | (int64(buf[i+3+7]) << 24) | (int64(buf[i+4+7]) << 32) | (int64(buf[i+5+7]) << 40) | (int64(buf[i+6+7]) << 48) | (int64(buf[i+7+7]) << 56) + + } + { + + d.UI8 = 0 | (uint8(buf[i+0+15]) << 0) + + } + { + + d.UI16 = 0 | (uint16(buf[i+0+16]) << 0) | (uint16(buf[i+1+16]) << 8) + + } + { + + d.UI32 = 0 | (uint32(buf[i+0+18]) << 0) | (uint32(buf[i+1+18]) << 8) | (uint32(buf[i+2+18]) << 16) | (uint32(buf[i+3+18]) << 24) + + } + { + + d.UI64 = 0 | (uint64(buf[i+0+22]) << 0) | (uint64(buf[i+1+22]) << 8) | (uint64(buf[i+2+22]) << 16) | (uint64(buf[i+3+22]) << 24) | (uint64(buf[i+4+22]) << 32) | 
(uint64(buf[i+5+22]) << 40) | (uint64(buf[i+6+22]) << 48) | (uint64(buf[i+7+22]) << 56) + + } + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+30] & 0x7F) + for buf[i+30]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+30]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + d.S = string(buf[i+30 : i+30+l]) + i += l + } + { + if buf[i+30] == 1 { + if d.Sp == nil { + d.Sp = new(string) + } + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+31] & 0x7F) + for buf[i+31]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+31]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + (*d.Sp) = string(buf[i+31 : i+31+l]) + i += l + } + i += 0 + } else { + d.Sp = nil + } + } + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+31] & 0x7F) + for buf[i+31]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+31]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + if uint64(cap(d.Sa)) >= l { + d.Sa = d.Sa[:l] + } else { + d.Sa = make([]string, l) + } + for k0 := range d.Sa { + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+31] & 0x7F) + for buf[i+31]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+31]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + d.Sa[k0] = string(buf[i+31 : i+31+l]) + i += l + } + + } + } + { + if buf[i+31] == 1 { + if d.Sap == nil { + d.Sap = new([]string) + } + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+32] & 0x7F) + for buf[i+32]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+32]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + if uint64(cap((*d.Sap))) >= l { + (*d.Sap) = (*d.Sap)[:l] + } else { + (*d.Sap) = make([]string, l) + } + for k0 := range *d.Sap { + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+32] & 0x7F) + for buf[i+32]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+32]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + (*d.Sap)[k0] = string(buf[i+32 : i+32+l]) + i += l + } + + } + } + i += 0 + } else { + d.Sap = nil + } + } + { + d.B = buf[i+32] + } + { + if buf[i+33] == 1 { + if d.Bp == nil { + d.Bp = new(byte) + } + + { + (*d.Bp) = buf[i+34] + } + i++ + } else { + d.Bp = nil + } + } + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+34] & 0x7F) + for buf[i+34]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+34]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + if uint64(cap(d.Ba)) >= l { + d.Ba = d.Ba[:l] + } else { + d.Ba = make([]byte, l) + } + copy(d.Ba, buf[i+34:]) + i += l + } + { + if buf[i+34] == 1 { + if d.Bap == nil { + d.Bap = new([]byte) + } + + { + l := uint64(0) + + { + + bs := uint8(7) + t := uint64(buf[i+35] & 0x7F) + for buf[i+35]&0x80 == 0x80 { + i++ + t |= uint64(buf[i+35]&0x7F) << bs + bs += 7 + } + i++ + + l = t + + } + if uint64(cap((*d.Bap))) >= l { + (*d.Bap) = (*d.Bap)[:l] + } else { + (*d.Bap) = make([]byte, l) + } + copy((*d.Bap), buf[i+35:]) + i += l + } + i += 0 + } else { + d.Bap = nil + } + } + return i + 35, nil +} diff --git a/base/formats/dsd/http.go b/base/formats/dsd/http.go new file mode 100644 index 000000000..85aab163a --- /dev/null +++ b/base/formats/dsd/http.go @@ -0,0 +1,178 @@ +package dsd + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "strings" +) + +// HTTP Related Errors. +var ( + ErrMissingBody = errors.New("dsd: missing http body") + ErrMissingContentType = errors.New("dsd: missing http content type") +) + +const ( + httpHeaderContentType = "Content-Type" +) + +// LoadFromHTTPRequest loads the data from the body into the given interface. 
+func LoadFromHTTPRequest(r *http.Request, t interface{}) (format uint8, err error) { + return loadFromHTTP(r.Body, r.Header.Get(httpHeaderContentType), t) +} + +// LoadFromHTTPResponse loads the data from the body into the given interface. +// Closing the body is left to the caller. +func LoadFromHTTPResponse(resp *http.Response, t interface{}) (format uint8, err error) { + return loadFromHTTP(resp.Body, resp.Header.Get(httpHeaderContentType), t) +} + +func loadFromHTTP(body io.Reader, mimeType string, t interface{}) (format uint8, err error) { + // Read full body. + data, err := io.ReadAll(body) + if err != nil { + return 0, fmt.Errorf("dsd: failed to read http body: %w", err) + } + + // Load depending on mime type. + return MimeLoad(data, mimeType, t) +} + +// RequestHTTPResponseFormat sets the Accept header to the given format. +func RequestHTTPResponseFormat(r *http.Request, format uint8) (mimeType string, err error) { + // Get mime type. + mimeType, ok := FormatToMimeType[format] + if !ok { + return "", ErrIncompatibleFormat + } + + // Request response format. + r.Header.Set("Accept", mimeType) + + return mimeType, nil +} + +// DumpToHTTPRequest dumps the given data to the HTTP request using the given +// format. It also sets the Accept header to the same format. +func DumpToHTTPRequest(r *http.Request, t interface{}, format uint8) error { + // Get mime type and set request format. + mimeType, err := RequestHTTPResponseFormat(r, format) + if err != nil { + return err + } + + // Serialize data. + data, err := dumpWithoutIdentifier(t, format, "") + if err != nil { + return fmt.Errorf("dsd: failed to serialize: %w", err) + } + + // Add data to request. + r.Header.Set("Content-Type", mimeType) + r.Body = io.NopCloser(bytes.NewReader(data)) + + return nil +} + +// DumpToHTTPResponse dumpts the given data to the HTTP response, using the +// format defined in the request's Accept header. +func DumpToHTTPResponse(w http.ResponseWriter, r *http.Request, t interface{}) error { + // Serialize data based on accept header. + data, mimeType, _, err := MimeDump(t, r.Header.Get("Accept")) + if err != nil { + return fmt.Errorf("dsd: failed to serialize: %w", err) + } + + // Write data to response + w.Header().Set("Content-Type", mimeType) + _, err = w.Write(data) + if err != nil { + return fmt.Errorf("dsd: failed to write response: %w", err) + } + return nil +} + +// MimeLoad loads the given data into the interface based on the given mime type accept header. +func MimeLoad(data []byte, accept string, t interface{}) (format uint8, err error) { + // Find format. + format = FormatFromAccept(accept) + if format == 0 { + return 0, ErrIncompatibleFormat + } + + // Load data. + err = LoadAsFormat(data, format, t) + return format, err +} + +// MimeDump dumps the given interface based on the given mime type accept header. +func MimeDump(t any, accept string) (data []byte, mimeType string, format uint8, err error) { + // Find format. + format = FormatFromAccept(accept) + if format == AUTO { + return nil, "", 0, ErrIncompatibleFormat + } + + // Serialize and return. + data, err = dumpWithoutIdentifier(t, format, "") + return data, mimeType, format, err +} + +// FormatFromAccept returns the format for the given accept definition. +// The accept parameter matches the format of the HTTP Accept header. +// Special cases, in this order: +// - If accept is an empty string: returns default serialization format. +// - If accept contains no supported format, but a wildcard: returns default serialization format. 
+// - If accept contains no supported format, and no wildcard: returns AUTO format. +func FormatFromAccept(accept string) (format uint8) { + if accept == "" { + return DefaultSerializationFormat + } + + var foundWildcard bool + for _, mimeType := range strings.Split(accept, ",") { + // Clean mime type. + mimeType = strings.TrimSpace(mimeType) + mimeType, _, _ = strings.Cut(mimeType, ";") + if strings.Contains(mimeType, "/") { + _, mimeType, _ = strings.Cut(mimeType, "/") + } + mimeType = strings.ToLower(mimeType) + + // Check if mime type is supported. + format, ok := MimeTypeToFormat[mimeType] + if ok { + return format + } + + // Return default mime type as fallback if any mimetype is okay. + if mimeType == "*" { + foundWildcard = true + } + } + + if foundWildcard { + return DefaultSerializationFormat + } + return AUTO +} + +// Format and MimeType mappings. +var ( + FormatToMimeType = map[uint8]string{ + CBOR: "application/cbor", + JSON: "application/json", + MsgPack: "application/msgpack", + YAML: "application/yaml", + } + MimeTypeToFormat = map[string]uint8{ + "cbor": CBOR, + "json": JSON, + "msgpack": MsgPack, + "yaml": YAML, + "yml": YAML, + } +) diff --git a/base/formats/dsd/http_test.go b/base/formats/dsd/http_test.go new file mode 100644 index 000000000..32651ac84 --- /dev/null +++ b/base/formats/dsd/http_test.go @@ -0,0 +1,45 @@ +package dsd + +import ( + "mime" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMimeTypes(t *testing.T) { + t.Parallel() + + // Test static maps. + for _, mimeType := range FormatToMimeType { + cleaned, _, err := mime.ParseMediaType(mimeType) + assert.NoError(t, err, "mime type must be parse-able") + assert.Equal(t, mimeType, cleaned, "mime type should be clean in map already") + } + for mimeType := range MimeTypeToFormat { + cleaned, _, err := mime.ParseMediaType(mimeType) + assert.NoError(t, err, "mime type must be parse-able") + assert.Equal(t, mimeType, cleaned, "mime type should be clean in map already") + } + + // Test assumptions. + for accept, format := range map[string]uint8{ + "application/json, image/webp": JSON, + "image/webp, application/json": JSON, + "application/json;q=0.9, image/webp": JSON, + "*": DefaultSerializationFormat, + "*/*": DefaultSerializationFormat, + "text/yAMl": YAML, + " * , yaml ": YAML, + "yaml;charset ,*": YAML, + "xml,*": DefaultSerializationFormat, + "text/xml, text/other": AUTO, + "text/*": DefaultSerializationFormat, + "yaml ;charset": AUTO, // Invalid mimetype format. + "": DefaultSerializationFormat, + "x": AUTO, + } { + derivedFormat := FormatFromAccept(accept) + assert.Equal(t, format, derivedFormat, "assumption for %q should hold", accept) + } +} diff --git a/base/formats/dsd/interfaces.go b/base/formats/dsd/interfaces.go new file mode 100644 index 000000000..cae605241 --- /dev/null +++ b/base/formats/dsd/interfaces.go @@ -0,0 +1,9 @@ +package dsd + +// GenCodeCompatible is an interface to identify and use gencode compatible structs. +type GenCodeCompatible interface { + // GenCodeMarshal gencode marshalls the struct into the given byte array, or a new one if its too small. + GenCodeMarshal(buf []byte) ([]byte, error) + // GenCodeUnmarshal gencode unmarshalls the struct and returns the bytes read. 
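Editor's note: on the client side, DumpToHTTPRequest serializes the payload, sets Content-Type, and asks for the same format via Accept; on the server side, LoadFromHTTPRequest decodes based on Content-Type and DumpToHTTPResponse answers in whatever format the Accept header negotiates. A compact sketch of a matching handler and client call, assuming a hypothetical endpoint URL and payload type.

```go
package example

import (
	"net/http"

	"github.com/safing/portmaster/base/formats/dsd"
)

// pingMsg is a hypothetical payload type used only for illustration.
type pingMsg struct {
	Msg string
}

// handlePing decodes the request body based on its Content-Type header and
// answers in the format negotiated from the client's Accept header.
func handlePing(w http.ResponseWriter, r *http.Request) {
	in := &pingMsg{}
	if _, err := dsd.LoadFromHTTPRequest(r, in); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	out := &pingMsg{Msg: "pong: " + in.Msg}
	if err := dsd.DumpToHTTPResponse(w, r, out); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}

// sendPing serializes the payload into the request as CBOR and asks the
// server to reply in CBOR as well. The URL is an illustrative assumption.
func sendPing(msg string) (*pingMsg, error) {
	req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:817/api/ping", nil)
	if err != nil {
		return nil, err
	}
	if err := dsd.DumpToHTTPRequest(req, &pingMsg{Msg: msg}, dsd.CBOR); err != nil {
		return nil, err
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	reply := &pingMsg{}
	if _, err := dsd.LoadFromHTTPResponse(resp, reply); err != nil {
		return nil, err
	}
	return reply, nil
}
```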
+ GenCodeUnmarshal(buf []byte) (uint64, error) +} diff --git a/base/formats/dsd/tests.gencode b/base/formats/dsd/tests.gencode new file mode 100644 index 000000000..bc29f5d36 --- /dev/null +++ b/base/formats/dsd/tests.gencode @@ -0,0 +1,23 @@ +struct SimpleTestStruct { + S string + B byte +} + +struct GenCodeTestStructure { + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + S string + Sp *string + Sa []string + Sap *[]string + B byte + Bp *byte + Ba []byte + Bap *[]byte +} diff --git a/base/formats/varint/helpers.go b/base/formats/varint/helpers.go new file mode 100644 index 000000000..0aa2c8154 --- /dev/null +++ b/base/formats/varint/helpers.go @@ -0,0 +1,48 @@ +package varint + +import "errors" + +// PrependLength prepends the varint encoded length of the byte slice to itself. +func PrependLength(data []byte) []byte { + return append(Pack64(uint64(len(data))), data...) +} + +// GetNextBlock extract the integer from the beginning of the given byte slice and returns the remaining bytes, the extracted integer, and whether there was an error. +func GetNextBlock(data []byte) ([]byte, int, error) { + l, n, err := Unpack64(data) + if err != nil { + return nil, 0, err + } + length := int(l) + totalLength := length + n + if totalLength > len(data) { + return nil, 0, errors.New("varint: not enough data for given block length") + } + return data[n:totalLength], totalLength, nil +} + +// EncodedSize returns the size required to varint-encode an uint. +func EncodedSize(n uint64) (size int) { + switch { + case n < 1<<7: // < 128 + return 1 + case n < 1<<14: // < 16384 + return 2 + case n < 1<<21: // < 2097152 + return 3 + case n < 1<<28: // < 268435456 + return 4 + case n < 1<<35: // < 34359738368 + return 5 + case n < 1<<42: // < 4398046511104 + return 6 + case n < 1<<49: // < 562949953421312 + return 7 + case n < 1<<56: // < 72057594037927936 + return 8 + case n < 1<<63: // < 9223372036854775808 + return 9 + default: + return 10 + } +} diff --git a/base/formats/varint/varint.go b/base/formats/varint/varint.go new file mode 100644 index 000000000..05880e09d --- /dev/null +++ b/base/formats/varint/varint.go @@ -0,0 +1,97 @@ +package varint + +import ( + "encoding/binary" + "errors" +) + +// ErrBufTooSmall is returned when there is not enough data for parsing a varint. +var ErrBufTooSmall = errors.New("varint: buf too small") + +// Pack8 packs a uint8 into a VarInt. +func Pack8(n uint8) []byte { + if n < 128 { + return []byte{n} + } + return []byte{n, 0x01} +} + +// Pack16 packs a uint16 into a VarInt. +func Pack16(n uint16) []byte { + buf := make([]byte, 3) + w := binary.PutUvarint(buf, uint64(n)) + return buf[:w] +} + +// Pack32 packs a uint32 into a VarInt. +func Pack32(n uint32) []byte { + buf := make([]byte, 5) + w := binary.PutUvarint(buf, uint64(n)) + return buf[:w] +} + +// Pack64 packs a uint64 into a VarInt. +func Pack64(n uint64) []byte { + buf := make([]byte, 10) + w := binary.PutUvarint(buf, n) + return buf[:w] +} + +// Unpack8 unpacks a VarInt into a uint8. It returns the extracted int, how many bytes were used and an error. +func Unpack8(blob []byte) (uint8, int, error) { + if len(blob) < 1 { + return 0, 0, ErrBufTooSmall + } + if blob[0] < 128 { + return blob[0], 1, nil + } + if len(blob) < 2 { + return 0, 0, ErrBufTooSmall + } + if blob[1] != 0x01 { + return 0, 0, errors.New("varint: encoded integer greater than 255 (uint8)") + } + return blob[0], 1, nil +} + +// Unpack16 unpacks a VarInt into a uint16. 
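Editor's note: PrependLength and GetNextBlock pair up to frame variable-length blocks in a byte stream: the writer prefixes each block with its varint-encoded length, and the reader peels blocks off one by one using the total length that GetNextBlock reports. A small framing sketch; the surrounding function names are assumptions for illustration.

```go
package example

import (
	"github.com/safing/portmaster/base/formats/varint"
)

// frameBlocks is a hypothetical writer that concatenates blocks, each
// prefixed with its varint-encoded length.
func frameBlocks(blocks [][]byte) []byte {
	var stream []byte
	for _, block := range blocks {
		stream = append(stream, varint.PrependLength(block)...)
	}
	return stream
}

// splitBlocks is the matching hypothetical reader: it repeatedly extracts one
// length-prefixed block and advances past it until the stream is consumed.
func splitBlocks(stream []byte) ([][]byte, error) {
	var blocks [][]byte
	for len(stream) > 0 {
		block, totalLength, err := varint.GetNextBlock(stream)
		if err != nil {
			return nil, err
		}
		blocks = append(blocks, block)
		stream = stream[totalLength:]
	}
	return blocks, nil
}
```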
It returns the extracted int, how many bytes were used and an error. +func Unpack16(blob []byte) (uint16, int, error) { + n, r := binary.Uvarint(blob) + if r == 0 { + return 0, 0, ErrBufTooSmall + } + if r < 0 { + return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)") + } + if n > 65535 { + return 0, 0, errors.New("varint: encoded integer greater than 65535 (uint16)") + } + return uint16(n), r, nil +} + +// Unpack32 unpacks a VarInt into a uint32. It returns the extracted int, how many bytes were used and an error. +func Unpack32(blob []byte) (uint32, int, error) { + n, r := binary.Uvarint(blob) + if r == 0 { + return 0, 0, ErrBufTooSmall + } + if r < 0 { + return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)") + } + if n > 4294967295 { + return 0, 0, errors.New("varint: encoded integer greater than 4294967295 (uint32)") + } + return uint32(n), r, nil +} + +// Unpack64 unpacks a VarInt into a uint64. It returns the extracted int, how many bytes were used and an error. +func Unpack64(blob []byte) (uint64, int, error) { + n, r := binary.Uvarint(blob) + if r == 0 { + return 0, 0, ErrBufTooSmall + } + if r < 0 { + return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)") + } + return n, r, nil +} diff --git a/base/formats/varint/varint_test.go b/base/formats/varint/varint_test.go new file mode 100644 index 000000000..9f2250ef7 --- /dev/null +++ b/base/formats/varint/varint_test.go @@ -0,0 +1,141 @@ +//nolint:gocognit +package varint + +import ( + "bytes" + "testing" +) + +func TestConversion(t *testing.T) { + t.Parallel() + + subjects := []struct { + intType uint8 + bytes []byte + integer uint64 + }{ + {8, []byte{0x00}, 0}, + {8, []byte{0x01}, 1}, + {8, []byte{0x7F}, 127}, + {8, []byte{0x80, 0x01}, 128}, + {8, []byte{0xFF, 0x01}, 255}, + + {16, []byte{0x80, 0x02}, 256}, + {16, []byte{0xFF, 0x7F}, 16383}, + {16, []byte{0x80, 0x80, 0x01}, 16384}, + {16, []byte{0xFF, 0xFF, 0x03}, 65535}, + + {32, []byte{0x80, 0x80, 0x04}, 65536}, + {32, []byte{0xFF, 0xFF, 0x7F}, 2097151}, + {32, []byte{0x80, 0x80, 0x80, 0x01}, 2097152}, + {32, []byte{0xFF, 0xFF, 0xFF, 0x07}, 16777215}, + {32, []byte{0x80, 0x80, 0x80, 0x08}, 16777216}, + {32, []byte{0xFF, 0xFF, 0xFF, 0x7F}, 268435455}, + {32, []byte{0x80, 0x80, 0x80, 0x80, 0x01}, 268435456}, + {32, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0x0F}, 4294967295}, + + {64, []byte{0x80, 0x80, 0x80, 0x80, 0x10}, 4294967296}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 34359738367}, + {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 34359738368}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x1F}, 1099511627775}, + {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x20}, 1099511627776}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 4398046511103}, + {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 4398046511104}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3F}, 281474976710655}, + {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x40}, 281474976710656}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 562949953421311}, + {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 562949953421312}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 72057594037927935}, + + {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 72057594037927936}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 9223372036854775807}, + + {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 
9223372036854775808}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, 18446744073709551615}, + } + + for _, subject := range subjects { + + actualInteger, _, err := Unpack64(subject.bytes) + if err != nil || actualInteger != subject.integer { + t.Errorf("Unpack64 %d: expected %d, actual %d", subject.bytes, subject.integer, actualInteger) + } + actualBytes := Pack64(subject.integer) + if err != nil || !bytes.Equal(actualBytes, subject.bytes) { + t.Errorf("Pack64 %d: expected %d, actual %d", subject.integer, subject.bytes, actualBytes) + } + + if subject.intType <= 32 { + actualInteger, _, err := Unpack32(subject.bytes) + if err != nil || actualInteger != uint32(subject.integer) { + t.Errorf("Unpack32 %d: expected %d, actual %d", subject.bytes, subject.integer, actualInteger) + } + actualBytes := Pack32(uint32(subject.integer)) + if err != nil || !bytes.Equal(actualBytes, subject.bytes) { + t.Errorf("Pack32 %d: expected %d, actual %d", subject.integer, subject.bytes, actualBytes) + } + } + + if subject.intType <= 16 { + actualInteger, _, err := Unpack16(subject.bytes) + if err != nil || actualInteger != uint16(subject.integer) { + t.Errorf("Unpack16 %d: expected %d, actual %d", subject.bytes, subject.integer, actualInteger) + } + actualBytes := Pack16(uint16(subject.integer)) + if err != nil || !bytes.Equal(actualBytes, subject.bytes) { + t.Errorf("Pack16 %d: expected %d, actual %d", subject.integer, subject.bytes, actualBytes) + } + } + + if subject.intType <= 8 { + actualInteger, _, err := Unpack8(subject.bytes) + if err != nil || actualInteger != uint8(subject.integer) { + t.Errorf("Unpack8 %d: expected %d, actual %d", subject.bytes, subject.integer, actualInteger) + } + actualBytes := Pack8(uint8(subject.integer)) + if err != nil || !bytes.Equal(actualBytes, subject.bytes) { + t.Errorf("Pack8 %d: expected %d, actual %d", subject.integer, subject.bytes, actualBytes) + } + } + + } +} + +func TestFails(t *testing.T) { + t.Parallel() + + subjects := []struct { + intType uint8 + bytes []byte + }{ + {32, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}}, + {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02}}, + {64, []byte{0xFF}}, + } + + for _, subject := range subjects { + + if subject.intType == 64 { + _, _, err := Unpack64(subject.bytes) + if err == nil { + t.Errorf("Unpack64 %d: expected error while unpacking.", subject.bytes) + } + } + + _, _, err := Unpack32(subject.bytes) + if err == nil { + t.Errorf("Unpack32 %d: expected error while unpacking.", subject.bytes) + } + + _, _, err = Unpack16(subject.bytes) + if err == nil { + t.Errorf("Unpack16 %d: expected error while unpacking.", subject.bytes) + } + + _, _, err = Unpack8(subject.bytes) + if err == nil { + t.Errorf("Unpack8 %d: expected error while unpacking.", subject.bytes) + } + + } +} diff --git a/base/info/module/flags.go b/base/info/module/flags.go new file mode 100644 index 000000000..f1c8af230 --- /dev/null +++ b/base/info/module/flags.go @@ -0,0 +1,38 @@ +package module + +import ( + "flag" + "fmt" + + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/modules" +) + +var showVersion bool + +func init() { + modules.Register("info", prep, nil, nil) + + flag.BoolVar(&showVersion, "version", false, "show version and exit") +} + +func prep() error { + err := info.CheckVersion() + if err != nil { + return err + } + + if printVersion() { + return modules.ErrCleanExit + } + return nil +} + +// printVersion prints the version, if requested, and 
returns if it did so. +func printVersion() (printed bool) { + if showVersion { + fmt.Println(info.FullVersion()) + return true + } + return false +} diff --git a/base/info/version.go b/base/info/version.go new file mode 100644 index 000000000..91bad0925 --- /dev/null +++ b/base/info/version.go @@ -0,0 +1,169 @@ +package info + +import ( + "errors" + "fmt" + "os" + "runtime" + "runtime/debug" + "strings" + "sync" +) + +var ( + name string + license string + + version = "dev build" + versionNumber = "0.0.0" + buildSource = "unknown" + buildTime = "unknown" + + info *Info + loadInfo sync.Once +) + +func init() { + // Replace space placeholders. + buildSource = strings.ReplaceAll(buildSource, "_", " ") + buildTime = strings.ReplaceAll(buildTime, "_", " ") + + // Convert version string from git tag to expected format. + version = strings.TrimSpace(strings.ReplaceAll(strings.TrimPrefix(version, "v"), "_", " ")) + versionNumber = strings.TrimSpace(strings.TrimSuffix(version, "dev build")) + if versionNumber == "" { + versionNumber = "0.0.0" + } + + // Get build info. + buildInfo, _ := debug.ReadBuildInfo() + buildSettings := make(map[string]string) + for _, setting := range buildInfo.Settings { + buildSettings[setting.Key] = setting.Value + } + + // Add "dev build" to version if repo is dirty. + if buildSettings["vcs.modified"] == "true" && + !strings.HasSuffix(version, "dev build") { + version += " dev build" + } +} + +// Info holds the programs meta information. +type Info struct { //nolint:maligned + Name string + Version string + VersionNumber string + License string + + Source string + BuildTime string + CGO bool + + Commit string + CommitTime string + Dirty bool + + debug.BuildInfo +} + +// Set sets meta information via the main routine. This should be the first thing your program calls. +func Set(setName string, setVersion string, setLicenseName string) { + name = setName + license = setLicenseName + + if setVersion != "" { + version = setVersion + } +} + +// GetInfo returns all the meta information about the program. +func GetInfo() *Info { + loadInfo.Do(func() { + buildInfo, _ := debug.ReadBuildInfo() + buildSettings := make(map[string]string) + for _, setting := range buildInfo.Settings { + buildSettings[setting.Key] = setting.Value + } + + info = &Info{ + Name: name, + Version: version, + VersionNumber: versionNumber, + License: license, + Source: buildSource, + BuildTime: buildTime, + CGO: buildSettings["CGO_ENABLED"] == "1", + Commit: buildSettings["vcs.revision"], + CommitTime: buildSettings["vcs.time"], + Dirty: buildSettings["vcs.modified"] == "true", + BuildInfo: *buildInfo, + } + + if info.Commit == "" { + info.Commit = "unknown" + } + if info.CommitTime == "" { + info.CommitTime = "unknown" + } + }) + + return info +} + +// Version returns the annotated version. +func Version() string { + return version +} + +// VersionNumber returns the version number only. +func VersionNumber() string { + return versionNumber +} + +// FullVersion returns the full and detailed version string. +func FullVersion() string { + info := GetInfo() + builder := new(strings.Builder) + + // Name and version. + builder.WriteString(fmt.Sprintf("%s %s\n", info.Name, version)) + + // Build info. + cgoInfo := "-cgo" + if info.CGO { + cgoInfo = "+cgo" + } + builder.WriteString(fmt.Sprintf("\nbuilt with %s (%s %s) for %s/%s\n", runtime.Version(), runtime.Compiler, cgoInfo, runtime.GOOS, runtime.GOARCH)) + builder.WriteString(fmt.Sprintf(" at %s\n", info.BuildTime)) + + // Commit info. 
+ dirtyInfo := "clean" + if info.Dirty { + dirtyInfo = "dirty" + } + builder.WriteString(fmt.Sprintf("\ncommit %s (%s)\n", info.Commit, dirtyInfo)) + builder.WriteString(fmt.Sprintf(" at %s\n", info.CommitTime)) + builder.WriteString(fmt.Sprintf(" from %s\n", info.Source)) + + builder.WriteString(fmt.Sprintf("\nLicensed under the %s license.", license)) + + return builder.String() +} + +// CheckVersion checks if the metadata is ok. +func CheckVersion() error { + switch { + case strings.HasSuffix(os.Args[0], ".test"): + return nil // testing on linux/darwin + case strings.HasSuffix(os.Args[0], ".test.exe"): + return nil // testing on windows + default: + // check version information + if name == "" || license == "" { + return errors.New("must call SetInfo() before calling CheckVersion()") + } + } + + return nil +} diff --git a/base/log/flags.go b/base/log/flags.go new file mode 100644 index 000000000..eb0192979 --- /dev/null +++ b/base/log/flags.go @@ -0,0 +1,13 @@ +package log + +import "flag" + +var ( + logLevelFlag string + pkgLogLevelsFlag string +) + +func init() { + flag.StringVar(&logLevelFlag, "log", "", "set log level to [trace|debug|info|warning|error|critical]") + flag.StringVar(&pkgLogLevelsFlag, "plog", "", "set log level of packages: database=trace,notifications=debug") +} diff --git a/base/log/formatting.go b/base/log/formatting.go new file mode 100644 index 000000000..a9bd519d3 --- /dev/null +++ b/base/log/formatting.go @@ -0,0 +1,97 @@ +package log + +import ( + "fmt" + "time" +) + +var counter uint16 + +const ( + maxCount uint16 = 999 + timeFormat string = "060102 15:04:05.000" +) + +func (s Severity) String() string { + switch s { + case TraceLevel: + return "TRAC" + case DebugLevel: + return "DEBU" + case InfoLevel: + return "INFO" + case WarningLevel: + return "WARN" + case ErrorLevel: + return "ERRO" + case CriticalLevel: + return "CRIT" + default: + return "NONE" + } +} + +func formatLine(line *logLine, duplicates uint64, useColor bool) string { + colorStart := "" + colorEnd := "" + if useColor { + colorStart = line.level.color() + colorEnd = endColor() + } + + counter++ + + var fLine string + if line.line == 0 { + fLine = fmt.Sprintf("%s%s ? 
%s %s %03d%s%s %s", colorStart, line.timestamp.Format(timeFormat), rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg) + } else { + fLen := len(line.file) + fPartStart := fLen - 10 + if fPartStart < 0 { + fPartStart = 0 + } + fLine = fmt.Sprintf("%s%s %s:%03d %s %s %03d%s%s %s", colorStart, line.timestamp.Format(timeFormat), line.file[fPartStart:], line.line, rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg) + } + + if line.tracer != nil { + // append full trace time + if len(line.tracer.logs) > 0 { + fLine += fmt.Sprintf(" Σ=%s", line.timestamp.Sub(line.tracer.logs[0].timestamp)) + } + + // append all trace actions + var d time.Duration + for i, action := range line.tracer.logs { + // set color + if useColor { + colorStart = action.level.color() + } + // set filename length + fLen := len(action.file) + fPartStart := fLen - 10 + if fPartStart < 0 { + fPartStart = 0 + } + // format + if i == len(line.tracer.logs)-1 { // last + d = line.timestamp.Sub(action.timestamp) + } else { + d = line.tracer.logs[i+1].timestamp.Sub(action.timestamp) + } + fLine += fmt.Sprintf("\n%s%19s %s:%03d %s %s%s %s", colorStart, d, action.file[fPartStart:], action.line, rightArrow, action.level.String(), colorEnd, action.msg) + } + } + + if counter >= maxCount { + counter = 0 + } + + return fLine +} + +func formatDuplicates(duplicates uint64) string { + if duplicates == 0 { + return "" + } + return fmt.Sprintf(" [%dx]", duplicates+1) +} diff --git a/base/log/formatting_unix.go b/base/log/formatting_unix.go new file mode 100644 index 000000000..6be6fdcca --- /dev/null +++ b/base/log/formatting_unix.go @@ -0,0 +1,44 @@ +//go:build !windows + +package log + +const ( + rightArrow = "â–¶" + leftArrow = "â—€" +) + +const ( + colorRed = "\033[31m" + colorYellow = "\033[33m" + colorBlue = "\033[34m" + colorMagenta = "\033[35m" + colorCyan = "\033[36m" + + // Saved for later: + // colorBlack = "\033[30m" //. + // colorGreen = "\033[32m" //. + // colorWhite = "\033[37m" //. 
+) + +func (s Severity) color() string { + switch s { + case DebugLevel: + return colorCyan + case InfoLevel: + return colorBlue + case WarningLevel: + return colorYellow + case ErrorLevel: + return colorRed + case CriticalLevel: + return colorMagenta + case TraceLevel: + return "" + default: + return "" + } +} + +func endColor() string { + return "\033[0m" +} diff --git a/base/log/formatting_windows.go b/base/log/formatting_windows.go new file mode 100644 index 000000000..2c972d0a3 --- /dev/null +++ b/base/log/formatting_windows.go @@ -0,0 +1,56 @@ +package log + +import ( + "github.com/safing/portmaster/base/utils/osdetail" +) + +const ( + rightArrow = ">" + leftArrow = "<" +) + +const ( + // colorBlack = "\033[30m" + colorRed = "\033[31m" + // colorGreen = "\033[32m" + colorYellow = "\033[33m" + colorBlue = "\033[34m" + colorMagenta = "\033[35m" + colorCyan = "\033[36m" + // colorWhite = "\033[37m" +) + +var ( + colorsSupported bool +) + +func init() { + colorsSupported = osdetail.EnableColorSupport() +} + +func (s Severity) color() string { + if colorsSupported { + switch s { + case DebugLevel: + return colorCyan + case InfoLevel: + return colorBlue + case WarningLevel: + return colorYellow + case ErrorLevel: + return colorRed + case CriticalLevel: + return colorMagenta + default: + return "" + } + } + return "" +} + +func endColor() string { + if colorsSupported { + return "\033[0m" + } + return "" +} diff --git a/base/log/input.go b/base/log/input.go new file mode 100644 index 000000000..ef8564a91 --- /dev/null +++ b/base/log/input.go @@ -0,0 +1,219 @@ +package log + +import ( + "fmt" + "runtime" + "strings" + "sync/atomic" + "time" +) + +var ( + warnLogLines = new(uint64) + errLogLines = new(uint64) + critLogLines = new(uint64) +) + +func log(level Severity, msg string, tracer *ContextTracer) { + if !started.IsSet() { + // a bit resource intense, but keeps logs before logging started. 
+ // TODO: create option to disable logging + go func() { + <-startedSignal + log(level, msg, tracer) + }() + return + } + + // get time + now := time.Now() + + // get file and line + _, file, line, ok := runtime.Caller(2) + if !ok { + file = "" + line = 0 + } else { + if len(file) > 3 { + file = file[:len(file)-3] + } else { + file = "" + } + } + + // check if level is enabled for file or generally + if pkgLevelsActive.IsSet() { + pathSegments := strings.Split(file, "/") + if len(pathSegments) < 2 { + // file too short for package levels + return + } + pkgLevelsLock.Lock() + severity, ok := pkgLevels[pathSegments[len(pathSegments)-2]] + pkgLevelsLock.Unlock() + if ok { + if level < severity { + return + } + } else { + // no package level set, check against global level + if uint32(level) < atomic.LoadUint32(logLevel) { + return + } + } + } else if uint32(level) < atomic.LoadUint32(logLevel) { + // no package levels set, check against global level + return + } + + // create log object + log := &logLine{ + msg: msg, + tracer: tracer, + level: level, + timestamp: now, + file: file, + line: line, + } + + // send log to processing + select { + case logBuffer <- log: + default: + forceEmptyingLoop: + // force empty buffer until we can send to it + for { + select { + case forceEmptyingOfBuffer <- struct{}{}: + case logBuffer <- log: + break forceEmptyingLoop + } + } + } + + // wake up writer if necessary + if logsWaitingFlag.SetToIf(false, true) { + select { + case logsWaiting <- struct{}{}: + default: + } + } +} + +func fastcheck(level Severity) bool { + if pkgLevelsActive.IsSet() { + return true + } + if uint32(level) >= atomic.LoadUint32(logLevel) { + return true + } + return false +} + +// Trace is used to log tiny steps. Log traces to context if you can! +func Trace(msg string) { + if fastcheck(TraceLevel) { + log(TraceLevel, msg, nil) + } +} + +// Tracef is used to log tiny steps. Log traces to context if you can! +func Tracef(format string, things ...interface{}) { + if fastcheck(TraceLevel) { + log(TraceLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Debug is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in itself, but they might hint at a bigger problem. +func Debug(msg string) { + if fastcheck(DebugLevel) { + log(DebugLevel, msg, nil) + } +} + +// Debugf is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in itself, but they might hint at a bigger problem. +func Debugf(format string, things ...interface{}) { + if fastcheck(DebugLevel) { + log(DebugLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Info is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen. +func Info(msg string) { + if fastcheck(InfoLevel) { + log(InfoLevel, msg, nil) + } +} + +// Infof is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen. +func Infof(format string, things ...interface{}) { + if fastcheck(InfoLevel) { + log(InfoLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Warning is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet. +func Warning(msg string) { + atomic.AddUint64(warnLogLines, 1) + if fastcheck(WarningLevel) { + log(WarningLevel, msg, nil) + } +} + +// Warningf is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet. 
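For orientation, a minimal usage sketch of the leveled logging calls introduced above; the import path matches this patch, while the program itself is purely illustrative:

```go
package main

import (
	"github.com/safing/portmaster/base/log"
)

func main() {
	// Nothing is written before Start(); calls made earlier are held back
	// until logging has started.
	if err := log.Start(); err != nil {
		panic(err)
	}
	defer log.Shutdown()

	// Raise verbosity so the Debugf line below is not filtered out
	// (the default level is info).
	log.SetLogLevel(log.DebugLevel)

	log.Info("example: starting up")
	log.Debugf("example: loaded %d rules", 42)
	log.Warning("example: config file missing, using defaults")
}
```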
+func Warningf(format string, things ...interface{}) { + atomic.AddUint64(warnLogLines, 1) + if fastcheck(WarningLevel) { + log(WarningLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Error is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational. Maybe User/Admin should be informed. +func Error(msg string) { + atomic.AddUint64(errLogLines, 1) + if fastcheck(ErrorLevel) { + log(ErrorLevel, msg, nil) + } +} + +// Errorf is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational. +func Errorf(format string, things ...interface{}) { + atomic.AddUint64(errLogLines, 1) + if fastcheck(ErrorLevel) { + log(ErrorLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Critical is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed. +func Critical(msg string) { + atomic.AddUint64(critLogLines, 1) + if fastcheck(CriticalLevel) { + log(CriticalLevel, msg, nil) + } +} + +// Criticalf is used to log events that completely break the system. Operation cannot continue. User/Admin must be informed. +func Criticalf(format string, things ...interface{}) { + atomic.AddUint64(critLogLines, 1) + if fastcheck(CriticalLevel) { + log(CriticalLevel, fmt.Sprintf(format, things...), nil) + } +} + +// TotalWarningLogLines returns the total amount of warning log lines since +// start of the program. +func TotalWarningLogLines() uint64 { + return atomic.LoadUint64(warnLogLines) +} + +// TotalErrorLogLines returns the total amount of error log lines since start +// of the program. +func TotalErrorLogLines() uint64 { + return atomic.LoadUint64(errLogLines) +} + +// TotalCriticalLogLines returns the total amount of critical log lines since +// start of the program. +func TotalCriticalLogLines() uint64 { + return atomic.LoadUint64(critLogLines) +} diff --git a/base/log/logging.go b/base/log/logging.go new file mode 100644 index 000000000..fe777abad --- /dev/null +++ b/base/log/logging.go @@ -0,0 +1,243 @@ +package log + +import ( + "fmt" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/tevino/abool" +) + +// concept +/* +- Logging function: + - check if file-based levelling enabled + - if yes, check if level is active on this file + - check if level is active + - send data to backend via big buffered channel +- Backend: + - wait until there is time for writing logs + - write logs + - configurable if logged to folder (buffer + rollingFileAppender) and/or console + - console: log everything above INFO to stderr +- Channel overbuffering protection: + - if buffer is full, trigger write +- Anti-Importing-Loop: + - everything imports logging + - logging is configured by main module and is supplied access to configuration and taskmanager +*/ + +// Severity describes a log level. +type Severity uint32 + +// Message describes a log level message and is implemented +// by logLine. 
+type Message interface { + Text() string + Severity() Severity + Time() time.Time + File() string + LineNumber() int +} + +type logLine struct { + msg string + tracer *ContextTracer + level Severity + timestamp time.Time + file string + line int +} + +func (ll *logLine) Text() string { + return ll.msg +} + +func (ll *logLine) Severity() Severity { + return ll.level +} + +func (ll *logLine) Time() time.Time { + return ll.timestamp +} + +func (ll *logLine) File() string { + return ll.file +} + +func (ll *logLine) LineNumber() int { + return ll.line +} + +func (ll *logLine) Equal(ol *logLine) bool { + switch { + case ll.msg != ol.msg: + return false + case ll.tracer != nil || ol.tracer != nil: + return false + case ll.file != ol.file: + return false + case ll.line != ol.line: + return false + case ll.level != ol.level: + return false + } + return true +} + +// Log Levels. +const ( + TraceLevel Severity = 1 + DebugLevel Severity = 2 + InfoLevel Severity = 3 + WarningLevel Severity = 4 + ErrorLevel Severity = 5 + CriticalLevel Severity = 6 +) + +var ( + logBuffer chan *logLine + forceEmptyingOfBuffer = make(chan struct{}) + + logLevelInt = uint32(InfoLevel) + logLevel = &logLevelInt + + pkgLevelsActive = abool.NewBool(false) + pkgLevels = make(map[string]Severity) + pkgLevelsLock sync.Mutex + + logsWaiting = make(chan struct{}, 1) + logsWaitingFlag = abool.NewBool(false) + + shutdownFlag = abool.NewBool(false) + shutdownSignal = make(chan struct{}) + shutdownWaitGroup sync.WaitGroup + + initializing = abool.NewBool(false) + started = abool.NewBool(false) + startedSignal = make(chan struct{}) +) + +// SetPkgLevels sets individual log levels for packages. Only effective after Start(). +func SetPkgLevels(levels map[string]Severity) { + pkgLevelsLock.Lock() + pkgLevels = levels + pkgLevelsLock.Unlock() + pkgLevelsActive.Set() +} + +// UnSetPkgLevels removes all individual log levels for packages. +func UnSetPkgLevels() { + pkgLevelsActive.UnSet() +} + +// GetLogLevel returns the current log level. +func GetLogLevel() Severity { + return Severity(atomic.LoadUint32(logLevel)) +} + +// SetLogLevel sets a new log level. Only effective after Start(). +func SetLogLevel(level Severity) { + atomic.StoreUint32(logLevel, uint32(level)) +} + +// Name returns the name of the log level. +func (s Severity) Name() string { + switch s { + case TraceLevel: + return "trace" + case DebugLevel: + return "debug" + case InfoLevel: + return "info" + case WarningLevel: + return "warning" + case ErrorLevel: + return "error" + case CriticalLevel: + return "critical" + default: + return "none" + } +} + +// ParseLevel returns the level severity of a log level name. +func ParseLevel(level string) Severity { + switch strings.ToLower(level) { + case "trace": + return 1 + case "debug": + return 2 + case "info": + return 3 + case "warning": + return 4 + case "error": + return 5 + case "critical": + return 6 + } + return 0 +} + +// Start starts the logging system. Must be called in order to see logs. 
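As a usage note for the level controls above, a short sketch (the helper function and package name are illustrative); the same effect is available via the -log and -plog flags that Start() honors:

```go
package example

import "github.com/safing/portmaster/base/log"

// configureLogging sets the global level and per-package overrides.
func configureLogging(globalLevel string) {
	// ParseLevel returns 0 for unknown names; fall back to info in that case.
	level := log.ParseLevel(globalLevel)
	if level == 0 {
		level = log.InfoLevel
	}
	log.SetLogLevel(level)

	// Per-package levels are keyed by package directory name and take
	// precedence over the global level until UnSetPkgLevels() is called.
	log.SetPkgLevels(map[string]log.Severity{
		"database":      log.TraceLevel,
		"notifications": log.DebugLevel,
	})
}
```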
+func Start() (err error) { + if !initializing.SetToIf(false, true) { + return nil + } + + logBuffer = make(chan *logLine, 1024) + + if logLevelFlag != "" { + initialLogLevel := ParseLevel(logLevelFlag) + if initialLogLevel == 0 { + fmt.Fprintf(os.Stderr, "log warning: invalid log level \"%s\", falling back to level info\n", logLevelFlag) + initialLogLevel = InfoLevel + } + + SetLogLevel(initialLogLevel) + } + + // get and set file loglevels + pkgLogLevels := pkgLogLevelsFlag + if len(pkgLogLevels) > 0 { + newPkgLevels := make(map[string]Severity) + for _, pair := range strings.Split(pkgLogLevels, ",") { + splitted := strings.Split(pair, "=") + if len(splitted) != 2 { + err = fmt.Errorf("log warning: invalid file log level \"%s\", ignoring", pair) + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + break + } + fileLevel := ParseLevel(splitted[1]) + if fileLevel == 0 { + err = fmt.Errorf("log warning: invalid file log level \"%s\", ignoring", pair) + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + break + } + newPkgLevels[splitted[0]] = fileLevel + } + SetPkgLevels(newPkgLevels) + } + + if !schedulingEnabled { + close(writeTrigger) + } + startWriter() + + started.Set() + close(startedSignal) + + return err +} + +// Shutdown writes remaining log lines and then stops the log system. +func Shutdown() { + if shutdownFlag.SetToIf(false, true) { + close(shutdownSignal) + } + shutdownWaitGroup.Wait() +} diff --git a/base/log/logging_test.go b/base/log/logging_test.go new file mode 100644 index 000000000..577ee51a5 --- /dev/null +++ b/base/log/logging_test.go @@ -0,0 +1,64 @@ +package log + +import ( + "fmt" + "testing" + "time" +) + +func init() { + err := Start() + if err != nil { + panic(fmt.Sprintf("start failed: %s", err)) + } +} + +func TestLogging(t *testing.T) { + t.Parallel() + + // skip + if testing.Short() { + t.Skip() + } + + // set levels (static random) + SetLogLevel(WarningLevel) + SetLogLevel(InfoLevel) + SetLogLevel(ErrorLevel) + SetLogLevel(DebugLevel) + SetLogLevel(CriticalLevel) + SetLogLevel(TraceLevel) + + // log + Trace("Trace") + Debug("Debug") + Info("Info") + Warning("Warning") + Error("Error") + Critical("Critical") + + // logf + Tracef("Trace %s", "f") + Debugf("Debug %s", "f") + Infof("Info %s", "f") + Warningf("Warning %s", "f") + Errorf("Error %s", "f") + Criticalf("Critical %s", "f") + + // play with levels + SetLogLevel(CriticalLevel) + Warning("Warning") + SetLogLevel(TraceLevel) + + // log invalid level + log(0xFF, "msg", nil) + + // wait logs to be written + time.Sleep(1 * time.Millisecond) + + // just for show + UnSetPkgLevels() + + // do not really shut down, we may need logging for other tests + // ShutdownLogging() +} diff --git a/base/log/output.go b/base/log/output.go new file mode 100644 index 000000000..d8a29a406 --- /dev/null +++ b/base/log/output.go @@ -0,0 +1,289 @@ +package log + +import ( + "fmt" + "os" + "runtime/debug" + "sync" + "time" +) + +type ( + // Adapter is used to write logs. + Adapter interface { + // Write is called for each log message. + Write(msg Message, duplicates uint64) + } + + // AdapterFunc is a convenience type for implementing + // Adapter. + AdapterFunc func(msg Message, duplicates uint64) + + // FormatFunc formats msg into a string. + FormatFunc func(msg Message, duplicates uint64) string + + // SimpleFileAdapter implements Adapter and writes all + // messages to File. 
+ SimpleFileAdapter struct { + Format FormatFunc + File *os.File + } +) + +var ( + // StdoutAdapter is a simple file adapter that writes + // all logs to os.Stdout using a predefined format. + StdoutAdapter = &SimpleFileAdapter{ + File: os.Stdout, + Format: defaultColorFormater, + } + + // StderrAdapter is a simple file adapter that writes + // all logs to os.Stdout using a predefined format. + StderrAdapter = &SimpleFileAdapter{ + File: os.Stderr, + Format: defaultColorFormater, + } +) + +var ( + adapter Adapter = StdoutAdapter + + schedulingEnabled = false + writeTrigger = make(chan struct{}) +) + +// SetAdapter configures the logging adapter to use. +// This must be called before the log package is initialized. +func SetAdapter(a Adapter) { + if initializing.IsSet() || a == nil { + return + } + + adapter = a +} + +// Write implements Adapter and calls fn. +func (fn AdapterFunc) Write(msg Message, duplicates uint64) { + fn(msg, duplicates) +} + +// Write implements Adapter and writes msg the underlying file. +func (fileAdapter *SimpleFileAdapter) Write(msg Message, duplicates uint64) { + fmt.Fprintln(fileAdapter.File, fileAdapter.Format(msg, duplicates)) +} + +// EnableScheduling enables external scheduling of the logger. This will require to manually trigger writes via TriggerWrite whenevery logs should be written. Please note that full buffers will also trigger writing. Must be called before Start() to have an effect. +func EnableScheduling() { + if !initializing.IsSet() { + schedulingEnabled = true + } +} + +// TriggerWriter triggers log output writing. +func TriggerWriter() { + if started.IsSet() && schedulingEnabled { + select { + case writeTrigger <- struct{}{}: + default: + } + } +} + +// TriggerWriterChannel returns the channel to trigger log writing. Returned channel will close if EnableScheduling() is not called correctly. +func TriggerWriterChannel() chan struct{} { + return writeTrigger +} + +func defaultColorFormater(line Message, duplicates uint64) string { + return formatLine(line.(*logLine), duplicates, true) //nolint:forcetypeassert // TODO: improve +} + +func startWriter() { + fmt.Printf("%s%s %s BOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor()) + + shutdownWaitGroup.Add(1) + go writerManager() +} + +func writerManager() { + defer shutdownWaitGroup.Done() + + for { + err := writer() + if err != nil { + Errorf("log: writer failed: %s", err) + } else { + return + } + } +} + +func writer() (err error) { + defer func() { + // recover from panic + panicVal := recover() + if panicVal != nil { + err = fmt.Errorf("%s", panicVal) + + // write stack to stderr + fmt.Fprintf( + os.Stderr, + `===== Error Report ===== +Message: %s +StackTrace: + +%s +===== End of Report ===== +`, + err, + string(debug.Stack()), + ) + } + }() + + var currentLine *logLine + var duplicates uint64 + + for { + // reset + currentLine = nil + duplicates = 0 + + // wait until logs need to be processed + select { + case <-logsWaiting: // normal process + logsWaitingFlag.UnSet() + case <-forceEmptyingOfBuffer: // log buffer is full! + case <-shutdownSignal: // shutting down + finalizeWriting() + return + } + + // wait for timeslot to log + select { + case <-writeTrigger: // normal process + case <-forceEmptyingOfBuffer: // log buffer is full! + case <-shutdownSignal: // shutting down + finalizeWriting() + return + } + + // write all the logs! 
+ writeLoop: + for { + select { + case nextLine := <-logBuffer: + // first line we process, just assign to currentLine + if currentLine == nil { + currentLine = nextLine + continue writeLoop + } + + // we now have currentLine and nextLine + + // if currentLine and nextLine are equal, do not print, just increase counter and continue + if nextLine.Equal(currentLine) { + duplicates++ + continue writeLoop + } + + // if currentLine and line are _not_ equal, output currentLine + adapter.Write(currentLine, duplicates) + // add to unexpected logs + addUnexpectedLogs(currentLine) + // reset duplicate counter + duplicates = 0 + // set new currentLine + currentLine = nextLine + default: + break writeLoop + } + } + + // write final line + if currentLine != nil { + adapter.Write(currentLine, duplicates) + // add to unexpected logs + addUnexpectedLogs(currentLine) + } + + // back down a little + select { + case <-time.After(10 * time.Millisecond): + case <-shutdownSignal: + finalizeWriting() + return + } + + } +} + +func finalizeWriting() { + for { + select { + case line := <-logBuffer: + adapter.Write(line, 0) + case <-time.After(10 * time.Millisecond): + fmt.Printf("%s%s %s EOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor()) + return + } + } +} + +// Last Unexpected Logs + +var ( + lastUnexpectedLogs [10]string + lastUnexpectedLogsIndex int + lastUnexpectedLogsLock sync.Mutex +) + +func addUnexpectedLogs(line *logLine) { + // Add main line. + if line.level >= WarningLevel { + addUnexpectedLogLine(line) + return + } + + // Check for unexpected lines in the tracer. + if line.tracer != nil { + for _, traceLine := range line.tracer.logs { + if traceLine.level >= WarningLevel { + // Add full trace. + addUnexpectedLogLine(line) + return + } + } + } +} + +func addUnexpectedLogLine(line *logLine) { + lastUnexpectedLogsLock.Lock() + defer lastUnexpectedLogsLock.Unlock() + + // Format line and add to logs. + lastUnexpectedLogs[lastUnexpectedLogsIndex] = formatLine(line, 0, false) + + // Increase index and wrap back to start. + lastUnexpectedLogsIndex = (lastUnexpectedLogsIndex + 1) % len(lastUnexpectedLogs) +} + +// GetLastUnexpectedLogs returns the last 10 log lines of level Warning an up. +func GetLastUnexpectedLogs() []string { + lastUnexpectedLogsLock.Lock() + defer lastUnexpectedLogsLock.Unlock() + + // Make a copy and return. + logsLen := len(lastUnexpectedLogs) + start := lastUnexpectedLogsIndex + logsCopy := make([]string, 0, logsLen) + // Loop from mid-to-mid. + for i := start; i < start+logsLen; i++ { + if lastUnexpectedLogs[i%logsLen] != "" { + logsCopy = append(logsCopy, lastUnexpectedLogs[i%logsLen]) + } + } + + return logsCopy +} diff --git a/base/log/trace.go b/base/log/trace.go new file mode 100644 index 000000000..640594d4c --- /dev/null +++ b/base/log/trace.go @@ -0,0 +1,280 @@ +package log + +import ( + "context" + "fmt" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" +) + +// ContextTracerKey is the key used for the context key/value storage. +type ContextTracerKey struct{} + +// ContextTracer is attached to a context in order bind logs to a context. +type ContextTracer struct { + sync.Mutex + logs []*logLine +} + +var key = ContextTracerKey{} + +// AddTracer adds a ContextTracer to the returned Context. Will return a nil ContextTracer if logging level is not set to trace. Will return a nil ContextTracer if one already exists. Will return a nil ContextTracer in case of an error. Will return a nil context if nil. 
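Before the context tracer below, a brief note on the Adapter plumbing above: output can be redirected by registering a custom adapter before Start(). A hedged sketch using the AdapterFunc convenience type (the formatting choices here are illustrative):

```go
package example

import (
	"fmt"
	"os"
	"time"

	"github.com/safing/portmaster/base/log"
)

// useCustomAdapter replaces the default stdout adapter. SetAdapter is only
// honored before the log package is initialized, so call it prior to Start().
func useCustomAdapter() error {
	log.SetAdapter(log.AdapterFunc(func(msg log.Message, duplicates uint64) {
		line := fmt.Sprintf("%s %s %s:%d %s",
			msg.Time().UTC().Format(time.RFC3339),
			msg.Severity().Name(),
			msg.File(), msg.LineNumber(),
			msg.Text(),
		)
		if duplicates > 0 {
			line += fmt.Sprintf(" [%dx]", duplicates+1)
		}
		fmt.Fprintln(os.Stderr, line)
	}))
	return log.Start()
}
```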
+func AddTracer(ctx context.Context) (context.Context, *ContextTracer) { + if ctx != nil && fastcheck(TraceLevel) { + // check pkg levels + if pkgLevelsActive.IsSet() { + // get file + _, file, _, ok := runtime.Caller(1) + if !ok { + // cannot get file, ignore + return ctx, nil + } + + pathSegments := strings.Split(file, "/") + if len(pathSegments) < 2 { + // file too short for package levels + return ctx, nil + } + pkgLevelsLock.Lock() + severity, ok := pkgLevels[pathSegments[len(pathSegments)-2]] + pkgLevelsLock.Unlock() + if ok { + // check against package level + if TraceLevel < severity { + return ctx, nil + } + } else { + // no package level set, check against global level + if uint32(TraceLevel) < atomic.LoadUint32(logLevel) { + return ctx, nil + } + } + } else if uint32(TraceLevel) < atomic.LoadUint32(logLevel) { + // no package levels set, check against global level + return ctx, nil + } + + // check for existing tracer + _, ok := ctx.Value(key).(*ContextTracer) + if !ok { + // add and return new tracer + tracer := &ContextTracer{} + return context.WithValue(ctx, key, tracer), tracer + } + } + return ctx, nil +} + +// Tracer returns the ContextTracer previously added to the given Context. +func Tracer(ctx context.Context) *ContextTracer { + if ctx != nil { + tracer, ok := ctx.Value(key).(*ContextTracer) + if ok { + return tracer + } + } + return nil +} + +// Submit collected logs on the context for further processing/outputting. Does nothing if called on a nil ContextTracer. +func (tracer *ContextTracer) Submit() { + if tracer == nil { + return + } + + if !started.IsSet() { + // a bit resource intense, but keeps logs before logging started. + // TODO: create option to disable logging + go func() { + <-startedSignal + tracer.Submit() + }() + return + } + + if len(tracer.logs) == 0 { + return + } + + // extract last line as main line + mainLine := tracer.logs[len(tracer.logs)-1] + tracer.logs = tracer.logs[:len(tracer.logs)-1] + + // create log object + log := &logLine{ + msg: mainLine.msg, + tracer: tracer, + level: mainLine.level, + timestamp: mainLine.timestamp, + file: mainLine.file, + line: mainLine.line, + } + + // send log to processing + select { + case logBuffer <- log: + default: + forceEmptyingLoop: + // force empty buffer until we can send to it + for { + select { + case forceEmptyingOfBuffer <- struct{}{}: + case logBuffer <- log: + break forceEmptyingLoop + } + } + } + + // wake up writer if necessary + if logsWaitingFlag.SetToIf(false, true) { + logsWaiting <- struct{}{} + } +} + +func (tracer *ContextTracer) log(level Severity, msg string) { + // get file and line + _, file, line, ok := runtime.Caller(2) + if !ok { + file = "" + line = 0 + } else { + if len(file) > 3 { + file = file[:len(file)-3] + } else { + file = "" + } + } + + tracer.Lock() + defer tracer.Unlock() + tracer.logs = append(tracer.logs, &logLine{ + timestamp: time.Now(), + level: level, + msg: msg, + file: file, + line: line, + }) +} + +// Trace is used to log tiny steps. Log traces to context if you can! +func (tracer *ContextTracer) Trace(msg string) { + switch { + case tracer != nil: + tracer.log(TraceLevel, msg) + case fastcheck(TraceLevel): + log(TraceLevel, msg, nil) + } +} + +// Tracef is used to log tiny steps. Log traces to context if you can! 
+func (tracer *ContextTracer) Tracef(format string, things ...interface{}) { + switch { + case tracer != nil: + tracer.log(TraceLevel, fmt.Sprintf(format, things...)) + case fastcheck(TraceLevel): + log(TraceLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Debug is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in itself, but they might hint at a bigger problem. +func (tracer *ContextTracer) Debug(msg string) { + switch { + case tracer != nil: + tracer.log(DebugLevel, msg) + case fastcheck(DebugLevel): + log(DebugLevel, msg, nil) + } +} + +// Debugf is used to log minor errors or unexpected events. These occurrences are usually not worth mentioning in itself, but they might hint at a bigger problem. +func (tracer *ContextTracer) Debugf(format string, things ...interface{}) { + switch { + case tracer != nil: + tracer.log(DebugLevel, fmt.Sprintf(format, things...)) + case fastcheck(DebugLevel): + log(DebugLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Info is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen. +func (tracer *ContextTracer) Info(msg string) { + switch { + case tracer != nil: + tracer.log(InfoLevel, msg) + case fastcheck(InfoLevel): + log(InfoLevel, msg, nil) + } +} + +// Infof is used to log mildly significant events. Should be used to inform about somewhat bigger or user affecting events that happen. +func (tracer *ContextTracer) Infof(format string, things ...interface{}) { + switch { + case tracer != nil: + tracer.log(InfoLevel, fmt.Sprintf(format, things...)) + case fastcheck(InfoLevel): + log(InfoLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Warning is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet. +func (tracer *ContextTracer) Warning(msg string) { + switch { + case tracer != nil: + tracer.log(WarningLevel, msg) + case fastcheck(WarningLevel): + log(WarningLevel, msg, nil) + } +} + +// Warningf is used to log (potentially) bad events, but nothing broke (even a little) and there is no need to panic yet. +func (tracer *ContextTracer) Warningf(format string, things ...interface{}) { + switch { + case tracer != nil: + tracer.log(WarningLevel, fmt.Sprintf(format, things...)) + case fastcheck(WarningLevel): + log(WarningLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Error is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational. Maybe User/Admin should be informed. +func (tracer *ContextTracer) Error(msg string) { + switch { + case tracer != nil: + tracer.log(ErrorLevel, msg) + case fastcheck(ErrorLevel): + log(ErrorLevel, msg, nil) + } +} + +// Errorf is used to log errors that break or impair functionality. The task/process may have to be aborted and tried again later. The system is still operational. +func (tracer *ContextTracer) Errorf(format string, things ...interface{}) { + switch { + case tracer != nil: + tracer.log(ErrorLevel, fmt.Sprintf(format, things...)) + case fastcheck(ErrorLevel): + log(ErrorLevel, fmt.Sprintf(format, things...), nil) + } +} + +// Critical is used to log events that completely break the system. Operation connot continue. User/Admin must be informed. 
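To round off the tracer API, a minimal sketch of the intended request-scoped flow (handler and helper names are made up for illustration): attach a tracer at the entry point, log against the context deeper in the call stack, and submit once at the end.

```go
package example

import (
	"context"

	"github.com/safing/portmaster/base/log"
)

func handleRequest(ctx context.Context) {
	// AddTracer returns a nil tracer unless trace level is active;
	// the tracer methods and Submit() are safe to call on nil.
	ctx, tracer := log.AddTracer(ctx)
	defer tracer.Submit()

	tracer.Trace("api: request received")
	doWork(ctx)
	tracer.Trace("api: request done")
}

func doWork(ctx context.Context) {
	// Deeper layers retrieve the same tracer from the context.
	log.Tracer(ctx).Trace("worker: doing work")
}
```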
+func (tracer *ContextTracer) Critical(msg string) { + switch { + case tracer != nil: + tracer.log(CriticalLevel, msg) + case fastcheck(CriticalLevel): + log(CriticalLevel, msg, nil) + } +} + +// Criticalf is used to log events that completely break the system. Operation connot continue. User/Admin must be informed. +func (tracer *ContextTracer) Criticalf(format string, things ...interface{}) { + switch { + case tracer != nil: + tracer.log(CriticalLevel, fmt.Sprintf(format, things...)) + case fastcheck(CriticalLevel): + log(CriticalLevel, fmt.Sprintf(format, things...), nil) + } +} diff --git a/base/log/trace_test.go b/base/log/trace_test.go new file mode 100644 index 000000000..999e0f301 --- /dev/null +++ b/base/log/trace_test.go @@ -0,0 +1,35 @@ +package log + +import ( + "context" + "testing" + "time" +) + +func TestContextTracer(t *testing.T) { + t.Parallel() + + // skip + if testing.Short() { + t.Skip() + } + + ctx, tracer := AddTracer(context.Background()) + _ = Tracer(ctx) + + tracer.Trace("api: request received, checking security") + time.Sleep(1 * time.Millisecond) + tracer.Trace("login: logging in user") + time.Sleep(1 * time.Millisecond) + tracer.Trace("database: fetching requested resources") + time.Sleep(10 * time.Millisecond) + tracer.Warning("database: partial failure") + time.Sleep(10 * time.Microsecond) + tracer.Trace("renderer: rendering output") + time.Sleep(1 * time.Millisecond) + tracer.Trace("api: returning request") + + tracer.Trace("api: completed request") + tracer.Submit() + time.Sleep(100 * time.Millisecond) +} diff --git a/base/metrics/api.go b/base/metrics/api.go new file mode 100644 index 000000000..fee2b8f3a --- /dev/null +++ b/base/metrics/api.go @@ -0,0 +1,158 @@ +package metrics + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "time" + + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" +) + +func registerAPI() error { + api.RegisterHandler("/metrics", &metricsAPI{}) + + if err := api.RegisterEndpoint(api.Endpoint{ + Name: "Export Registered Metrics", + Description: "List all registered metrics with their metadata.", + Path: "metrics/list", + Read: api.Dynamic, + BelongsTo: module, + StructFunc: func(ar *api.Request) (any, error) { + return ExportMetrics(ar.AuthToken.Read), nil + }, + }); err != nil { + return err + } + + if err := api.RegisterEndpoint(api.Endpoint{ + Name: "Export Metric Values", + Description: "List all exportable metric values.", + Path: "metrics/values", + Read: api.Dynamic, + Parameters: []api.Parameter{{ + Method: http.MethodGet, + Field: "internal-only", + Description: "Specify to only return metrics with an alternative internal ID.", + }}, + BelongsTo: module, + StructFunc: func(ar *api.Request) (any, error) { + return ExportValues( + ar.AuthToken.Read, + ar.Request.URL.Query().Has("internal-only"), + ), nil + }, + }); err != nil { + return err + } + + return nil +} + +type metricsAPI struct{} + +func (m *metricsAPI) ReadPermission(*http.Request) api.Permission { return api.Dynamic } + +func (m *metricsAPI) WritePermission(*http.Request) api.Permission { return api.NotSupported } + +func (m *metricsAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Get API Request for permission and query. + ar := api.GetAPIRequest(r) + if ar == nil { + http.Error(w, "Missing API Request.", http.StatusInternalServerError) + return + } + + // Get expertise level from query. 
+ expertiseLevel := config.ExpertiseLevelDeveloper + switch ar.Request.URL.Query().Get("level") { + case config.ExpertiseLevelNameUser: + expertiseLevel = config.ExpertiseLevelUser + case config.ExpertiseLevelNameExpert: + expertiseLevel = config.ExpertiseLevelExpert + case config.ExpertiseLevelNameDeveloper: + expertiseLevel = config.ExpertiseLevelDeveloper + } + + w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + w.WriteHeader(http.StatusOK) + WriteMetrics(w, ar.AuthToken.Read, expertiseLevel) +} + +// WriteMetrics writes all metrics that match the given permission and +// expertiseLevel to the given writer. +func WriteMetrics(w io.Writer, permission api.Permission, expertiseLevel config.ExpertiseLevel) { + registryLock.RLock() + defer registryLock.RUnlock() + + // Write all matching metrics. + for _, metric := range registry { + if permission >= metric.Opts().Permission && + expertiseLevel >= metric.Opts().ExpertiseLevel { + metric.WritePrometheus(w) + } + } +} + +func writeMetricsTo(ctx context.Context, url string) error { + // First, collect metrics into buffer. + buf := &bytes.Buffer{} + WriteMetrics(buf, api.PermitSelf, config.ExpertiseLevelDeveloper) + + // Check if there is something to send. + if buf.Len() == 0 { + log.Debugf("metrics: not pushing metrics, nothing to send") + return nil + } + + // Create request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, buf) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // Send. + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer func() { + _ = resp.Body.Close() + }() + + // Check return status. + if resp.StatusCode >= 200 && resp.StatusCode <= 299 { + return nil + } + + // Get and return error. + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf( + "got %s while writing metrics to %s: %s", + resp.Status, + url, + body, + ) +} + +func metricsWriter(ctx context.Context) error { + pushURL := pushOption() + ticker := module.NewSleepyTicker(1*time.Minute, 0) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.Wait(): + err := writeMetricsTo(ctx, pushURL) + if err != nil { + return err + } + } + } +} diff --git a/base/metrics/config.go b/base/metrics/config.go new file mode 100644 index 000000000..40bba265f --- /dev/null +++ b/base/metrics/config.go @@ -0,0 +1,108 @@ +package metrics + +import ( + "flag" + "os" + "strings" + + "github.com/safing/portmaster/base/config" +) + +// Configuration Keys. 
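As a consumption note: besides the /metrics HTTP handler above (which accepts a level query parameter of user, expert, or developer), the exported WriteMetrics helper can be called directly, for example to dump everything without filtering. A small illustrative sketch:

```go
package example

import (
	"os"

	"github.com/safing/portmaster/base/api"
	"github.com/safing/portmaster/base/config"
	"github.com/safing/portmaster/base/metrics"
)

// dumpAllMetrics writes every registered metric in Prometheus text format,
// using the highest permission and expertise level so nothing is filtered.
func dumpAllMetrics() {
	metrics.WriteMetrics(os.Stdout, api.PermitSelf, config.ExpertiseLevelDeveloper)
}
```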
+var ( + CfgOptionInstanceKey = "core/metrics/instance" + instanceOption config.StringOption + cfgOptionInstanceOrder = 0 + + CfgOptionCommentKey = "core/metrics/comment" + commentOption config.StringOption + cfgOptionCommentOrder = 0 + + CfgOptionPushKey = "core/metrics/push" + pushOption config.StringOption + cfgOptionPushOrder = 0 + + instanceFlag string + defaultInstance string + commentFlag string + pushFlag string +) + +func init() { + hostname, err := os.Hostname() + if err == nil { + hostname = strings.ReplaceAll(hostname, "-", "") + if prometheusFormat.MatchString(hostname) { + defaultInstance = hostname + } + } + + flag.StringVar(&instanceFlag, "metrics-instance", defaultInstance, "set the default metrics instance label for all metrics") + flag.StringVar(&commentFlag, "metrics-comment", "", "set the default metrics comment label") + flag.StringVar(&pushFlag, "push-metrics", "", "set default URL to push prometheus metrics to") +} + +func prepConfig() error { + err := config.Register(&config.Option{ + Name: "Metrics Instance Name", + Key: CfgOptionInstanceKey, + Description: "Define the prometheus instance label for all exported metrics. Please note that changing the metrics instance name will reset persisted metrics.", + Sensitive: true, + OptType: config.OptTypeString, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, + DefaultValue: instanceFlag, + RequiresRestart: true, + Annotations: config.Annotations{ + config.DisplayOrderAnnotation: cfgOptionInstanceOrder, + config.CategoryAnnotation: "Metrics", + }, + ValidationRegex: "^(" + prometheusBaseFormt + ")?$", + }) + if err != nil { + return err + } + instanceOption = config.Concurrent.GetAsString(CfgOptionInstanceKey, instanceFlag) + + err = config.Register(&config.Option{ + Name: "Metrics Comment Label", + Key: CfgOptionCommentKey, + Description: "Define a metrics comment label, which is added to the info metric.", + Sensitive: true, + OptType: config.OptTypeString, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, + DefaultValue: commentFlag, + RequiresRestart: true, + Annotations: config.Annotations{ + config.DisplayOrderAnnotation: cfgOptionCommentOrder, + config.CategoryAnnotation: "Metrics", + }, + }) + if err != nil { + return err + } + commentOption = config.Concurrent.GetAsString(CfgOptionCommentKey, commentFlag) + + err = config.Register(&config.Option{ + Name: "Push Prometheus Metrics", + Key: CfgOptionPushKey, + Description: "Push metrics to this URL in the prometheus format.", + Sensitive: true, + OptType: config.OptTypeString, + ExpertiseLevel: config.ExpertiseLevelExpert, + ReleaseLevel: config.ReleaseLevelStable, + DefaultValue: pushFlag, + RequiresRestart: true, + Annotations: config.Annotations{ + config.DisplayOrderAnnotation: cfgOptionPushOrder, + config.CategoryAnnotation: "Metrics", + }, + }) + if err != nil { + return err + } + pushOption = config.Concurrent.GetAsString(CfgOptionPushKey, pushFlag) + + return nil +} diff --git a/base/metrics/metric.go b/base/metrics/metric.go new file mode 100644 index 000000000..d5454d403 --- /dev/null +++ b/base/metrics/metric.go @@ -0,0 +1,165 @@ +package metrics + +import ( + "fmt" + "io" + "regexp" + "sort" + "strings" + + vm "github.com/VictoriaMetrics/metrics" + + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" +) + +// PrometheusFormatRequirement is required format defined by prometheus for +// metric and label names. 
+const ( + prometheusBaseFormt = "[a-zA-Z_][a-zA-Z0-9_]*" + PrometheusFormatRequirement = "^" + prometheusBaseFormt + "$" +) + +var prometheusFormat = regexp.MustCompile(PrometheusFormatRequirement) + +// Metric represents one or more metrics. +type Metric interface { + ID() string + LabeledID() string + Opts() *Options + WritePrometheus(w io.Writer) +} + +type metricBase struct { + Identifier string + Labels map[string]string + LabeledIdentifier string + Options *Options + set *vm.Set +} + +// Options can be used to set advanced metric settings. +type Options struct { + // Name defines an optional human readable name for the metric. + Name string + + // InternalID specifies an alternative internal ID that will be used when + // exposing the metric via the API in a structured format. + InternalID string + + // AlertLimit defines an upper limit that triggers an alert. + AlertLimit float64 + + // AlertTimeframe defines an optional timeframe in seconds for which the + // AlertLimit should be interpreted in. + AlertTimeframe float64 + + // Permission defines the permission that is required to read the metric. + Permission api.Permission + + // ExpertiseLevel defines the expertise level that the metric is meant for. + ExpertiseLevel config.ExpertiseLevel + + // Persist enabled persisting the metric on shutdown and loading the previous + // value at start. This is only supported for counters. + Persist bool +} + +func newMetricBase(id string, labels map[string]string, opts Options) (*metricBase, error) { + // Check formats. + if !prometheusFormat.MatchString(strings.ReplaceAll(id, "/", "_")) { + return nil, fmt.Errorf("metric name %q must match %s", id, PrometheusFormatRequirement) + } + for labelName := range labels { + if !prometheusFormat.MatchString(labelName) { + return nil, fmt.Errorf("metric label name %q must match %s", labelName, PrometheusFormatRequirement) + } + } + + // Check permission. + if opts.Permission < api.PermitAnyone { + // Default to PermitUser. + opts.Permission = api.PermitUser + } + + // Ensure that labels is a map. + if labels == nil { + labels = make(map[string]string) + } + + // Create metric base. + base := &metricBase{ + Identifier: id, + Labels: labels, + Options: &opts, + set: vm.NewSet(), + } + base.LabeledIdentifier = base.buildLabeledID() + return base, nil +} + +// ID returns the given ID of the metric. +func (m *metricBase) ID() string { + return m.Identifier +} + +// LabeledID returns the Prometheus-compatible labeled ID of the metric. +func (m *metricBase) LabeledID() string { + return m.LabeledIdentifier +} + +// Opts returns the metric options. They may not be modified. +func (m *metricBase) Opts() *Options { + return m.Options +} + +// WritePrometheus writes the metric in the prometheus format to the given writer. +func (m *metricBase) WritePrometheus(w io.Writer) { + m.set.WritePrometheus(w) +} + +func (m *metricBase) buildLabeledID() string { + // Because we use the namespace and the global flags here, we need to flag + // them as immutable. + registryLock.Lock() + defer registryLock.Unlock() + firstMetricRegistered = true + + // Build ID from Identifier. + metricID := strings.TrimSpace(strings.ReplaceAll(m.Identifier, "/", "_")) + + // Add namespace to ID. + if metricNamespace != "" { + metricID = metricNamespace + "_" + metricID + } + + // Return now if no labels are defined. + if len(globalLabels) == 0 && len(m.Labels) == 0 { + return metricID + } + + // Add global labels to the custom ones, if they don't exist yet. 
+ for labelName, labelValue := range globalLabels { + if _, ok := m.Labels[labelName]; !ok { + m.Labels[labelName] = labelValue + } + } + + // Render labels into a slice and sort them in order to make the labeled ID + // reproducible. + labels := make([]string, 0, len(m.Labels)) + for labelName, labelValue := range m.Labels { + labels = append(labels, fmt.Sprintf("%s=%q", labelName, labelValue)) + } + sort.Strings(labels) + + // Return fully labaled ID. + return fmt.Sprintf("%s{%s}", metricID, strings.Join(labels, ",")) +} + +// Split metrics into sets, according to the API Auth Levels, which will also correspond to the UI Mode levels. SPN // nodes will also allow public access to metrics with the permission "PermitAnyone". +// Save "life-long" metrics on shutdown and load them at start. +// Generate the correct metric name and labels. +// Expose metrics via http, but also via the runtime DB in order to push metrics to the UI. +// The UI will have to parse the prometheus metrics format and will not be able to immediately present historical data, // but data will have to be built. +// Provide the option to push metrics to a prometheus push gateway, this is especially helpful when gathering data from // loads of SPN nodes. diff --git a/base/metrics/metric_counter.go b/base/metrics/metric_counter.go new file mode 100644 index 000000000..90cf7c64e --- /dev/null +++ b/base/metrics/metric_counter.go @@ -0,0 +1,49 @@ +package metrics + +import ( + vm "github.com/VictoriaMetrics/metrics" +) + +// Counter is a counter metric. +type Counter struct { + *metricBase + *vm.Counter +} + +// NewCounter registers a new counter metric. +func NewCounter(id string, labels map[string]string, opts *Options) (*Counter, error) { + // Ensure that there are options. + if opts == nil { + opts = &Options{} + } + + // Make base. + base, err := newMetricBase(id, labels, *opts) + if err != nil { + return nil, err + } + + // Create metric struct. + m := &Counter{ + metricBase: base, + } + + // Create metric in set + m.Counter = m.set.NewCounter(m.LabeledID()) + + // Register metric. + err = register(m) + if err != nil { + return nil, err + } + + // Load state. + m.loadState() + + return m, nil +} + +// CurrentValue returns the current counter value. +func (c *Counter) CurrentValue() uint64 { + return c.Get() +} diff --git a/base/metrics/metric_counter_fetching.go b/base/metrics/metric_counter_fetching.go new file mode 100644 index 000000000..423d74ad8 --- /dev/null +++ b/base/metrics/metric_counter_fetching.go @@ -0,0 +1,62 @@ +package metrics + +import ( + "fmt" + "io" + + vm "github.com/VictoriaMetrics/metrics" +) + +// FetchingCounter is a counter metric that fetches the values via a function call. +type FetchingCounter struct { + *metricBase + counter *vm.Counter + fetchCnt func() uint64 +} + +// NewFetchingCounter registers a new fetching counter metric. +func NewFetchingCounter(id string, labels map[string]string, fn func() uint64, opts *Options) (*FetchingCounter, error) { + // Check if a fetch function is provided. + if fn == nil { + return nil, fmt.Errorf("%w: no fetch function provided", ErrInvalidOptions) + } + + // Ensure that there are options. + if opts == nil { + opts = &Options{} + } + + // Make base. + base, err := newMetricBase(id, labels, *opts) + if err != nil { + return nil, err + } + + // Create metric struct. + m := &FetchingCounter{ + metricBase: base, + fetchCnt: fn, + } + + // Create metric in set + m.counter = m.set.NewCounter(m.LabeledID()) + + // Register metric. 
+ err = register(m) + if err != nil { + return nil, err + } + + return m, nil +} + +// CurrentValue returns the current counter value. +func (fc *FetchingCounter) CurrentValue() uint64 { + return fc.fetchCnt() +} + +// WritePrometheus writes the metric in the prometheus format to the given writer. +func (fc *FetchingCounter) WritePrometheus(w io.Writer) { + fc.counter.Set(fc.fetchCnt()) + fc.metricBase.set.WritePrometheus(w) +} diff --git a/base/metrics/metric_export.go b/base/metrics/metric_export.go new file mode 100644 index 000000000..1df2750b6 --- /dev/null +++ b/base/metrics/metric_export.go @@ -0,0 +1,89 @@ +package metrics + +import ( + "github.com/safing/portmaster/base/api" +) + +// UIntMetric is an interface for special functions of uint metrics. +type UIntMetric interface { + CurrentValue() uint64 +} + +// FloatMetric is an interface for special functions of float metrics. +type FloatMetric interface { + CurrentValue() float64 +} + +// MetricExport is used to export a metric and its current value. +type MetricExport struct { + Metric + CurrentValue any +} + +// ExportMetrics exports all registered metrics. +func ExportMetrics(requestPermission api.Permission) []*MetricExport { + registryLock.RLock() + defer registryLock.RUnlock() + + export := make([]*MetricExport, 0, len(registry)) + for _, metric := range registry { + // Check permission. + if requestPermission < metric.Opts().Permission { + continue + } + + // Add metric with current value. + export = append(export, &MetricExport{ + Metric: metric, + CurrentValue: getCurrentValue(metric), + }) + } + + return export +} + +// ExportValues exports the values of all supported metrics. +func ExportValues(requestPermission api.Permission, internalOnly bool) map[string]any { + registryLock.RLock() + defer registryLock.RUnlock() + + export := make(map[string]any, len(registry)) + for _, metric := range registry { + // Check permission. + if requestPermission < metric.Opts().Permission { + continue + } + + // Get Value. + v := getCurrentValue(metric) + if v == nil { + continue + } + + // Get ID. + var id string + switch { + case metric.Opts().InternalID != "": + id = metric.Opts().InternalID + case internalOnly: + continue + default: + id = metric.LabeledID() + } + + // Add to export + export[id] = v + } + + return export +} + +func getCurrentValue(metric Metric) any { + if m, ok := metric.(UIntMetric); ok { + return m.CurrentValue() + } + if m, ok := metric.(FloatMetric); ok { + return m.CurrentValue() + } + return nil +} diff --git a/base/metrics/metric_gauge.go b/base/metrics/metric_gauge.go new file mode 100644 index 000000000..6e8ea6eab --- /dev/null +++ b/base/metrics/metric_gauge.go @@ -0,0 +1,46 @@ +package metrics + +import ( + vm "github.com/VictoriaMetrics/metrics" +) + +// Gauge is a gauge metric. +type Gauge struct { + *metricBase + *vm.Gauge +} + +// NewGauge registers a new gauge metric. +func NewGauge(id string, labels map[string]string, fn func() float64, opts *Options) (*Gauge, error) { + // Ensure that there are options. + if opts == nil { + opts = &Options{} + } + + // Make base. + base, err := newMetricBase(id, labels, *opts) + if err != nil { + return nil, err + } + + // Create metric struct. + m := &Gauge{ + metricBase: base, + } + + // Create metric in set + m.Gauge = m.set.NewGauge(m.LabeledID(), fn) + + // Register metric. + err = register(m) + if err != nil { + return nil, err + } + + return m, nil +} + +// CurrentValue returns the current gauge value. 
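For reference, a hedged sketch of registering metrics with the constructors introduced above (IDs, labels, and the callback are illustrative; slashes in IDs are converted to underscores for Prometheus):

```go
package example

import (
	"github.com/safing/portmaster/base/api"
	"github.com/safing/portmaster/base/metrics"
)

// registerExampleMetrics registers a manually incremented counter and a
// callback-driven gauge.
func registerExampleMetrics(queueLen func() float64) error {
	reqs, err := metrics.NewCounter(
		"example/requests/total",
		map[string]string{"handler": "api"},
		&metrics.Options{Name: "Handled Requests", Permission: api.PermitUser, Persist: true},
	)
	if err != nil {
		return err
	}
	reqs.Inc() // the embedded *vm.Counter methods are available directly

	_, err = metrics.NewGauge(
		"example/queue/length",
		nil,
		queueLen, // evaluated whenever the metric set is written
		&metrics.Options{Name: "Queue Length", Permission: api.PermitUser},
	)
	return err
}
```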
+func (g *Gauge) CurrentValue() float64 { + return g.Get() +} diff --git a/base/metrics/metric_histogram.go b/base/metrics/metric_histogram.go new file mode 100644 index 000000000..92c02101a --- /dev/null +++ b/base/metrics/metric_histogram.go @@ -0,0 +1,41 @@ +package metrics + +import ( + vm "github.com/VictoriaMetrics/metrics" +) + +// Histogram is a histogram metric. +type Histogram struct { + *metricBase + *vm.Histogram +} + +// NewHistogram registers a new histogram metric. +func NewHistogram(id string, labels map[string]string, opts *Options) (*Histogram, error) { + // Ensure that there are options. + if opts == nil { + opts = &Options{} + } + + // Make base. + base, err := newMetricBase(id, labels, *opts) + if err != nil { + return nil, err + } + + // Create metric struct. + m := &Histogram{ + metricBase: base, + } + + // Create metric in set + m.Histogram = m.set.NewHistogram(m.LabeledID()) + + // Register metric. + err = register(m) + if err != nil { + return nil, err + } + + return m, nil +} diff --git a/base/metrics/metrics_host.go b/base/metrics/metrics_host.go new file mode 100644 index 000000000..5b632a8df --- /dev/null +++ b/base/metrics/metrics_host.go @@ -0,0 +1,263 @@ +package metrics + +import ( + "runtime" + "sync" + "time" + + "github.com/shirou/gopsutil/disk" + "github.com/shirou/gopsutil/load" + "github.com/shirou/gopsutil/mem" + + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/log" +) + +const hostStatTTL = 1 * time.Second + +func registerHostMetrics() (err error) { + // Register load average metrics. + _, err = NewGauge("host/load/avg/1", nil, getFloat64HostStat(LoadAvg1), &Options{Name: "Host Load Avg 1min", Permission: api.PermitUser}) + if err != nil { + return err + } + _, err = NewGauge("host/load/avg/5", nil, getFloat64HostStat(LoadAvg5), &Options{Name: "Host Load Avg 5min", Permission: api.PermitUser}) + if err != nil { + return err + } + _, err = NewGauge("host/load/avg/15", nil, getFloat64HostStat(LoadAvg15), &Options{Name: "Host Load Avg 15min", Permission: api.PermitUser}) + if err != nil { + return err + } + + // Register memory usage metrics. + _, err = NewGauge("host/mem/total", nil, getUint64HostStat(MemTotal), &Options{Name: "Host Memory Total", Permission: api.PermitUser}) + if err != nil { + return err + } + _, err = NewGauge("host/mem/used", nil, getUint64HostStat(MemUsed), &Options{Name: "Host Memory Used", Permission: api.PermitUser}) + if err != nil { + return err + } + _, err = NewGauge("host/mem/available", nil, getUint64HostStat(MemAvailable), &Options{Name: "Host Memory Available", Permission: api.PermitUser}) + if err != nil { + return err + } + _, err = NewGauge("host/mem/used/percent", nil, getFloat64HostStat(MemUsedPercent), &Options{Name: "Host Memory Used in Percent", Permission: api.PermitUser}) + if err != nil { + return err + } + + // Register disk usage metrics. 
+ _, err = NewGauge("host/disk/total", nil, getUint64HostStat(DiskTotal), &Options{Name: "Host Disk Total", Permission: api.PermitUser}) + if err != nil { + return err + } + _, err = NewGauge("host/disk/used", nil, getUint64HostStat(DiskUsed), &Options{Name: "Host Disk Used", Permission: api.PermitUser}) + if err != nil { + return err + } + _, err = NewGauge("host/disk/free", nil, getUint64HostStat(DiskFree), &Options{Name: "Host Disk Free", Permission: api.PermitUser}) + if err != nil { + return err + } + _, err = NewGauge("host/disk/used/percent", nil, getFloat64HostStat(DiskUsedPercent), &Options{Name: "Host Disk Used in Percent", Permission: api.PermitUser}) + if err != nil { + return err + } + + return nil +} + +func getUint64HostStat(getStat func() (uint64, bool)) func() float64 { + return func() float64 { + val, _ := getStat() + return float64(val) + } +} + +func getFloat64HostStat(getStat func() (float64, bool)) func() float64 { + return func() float64 { + val, _ := getStat() + return val + } +} + +var ( + loadAvg *load.AvgStat + loadAvgExpires time.Time + loadAvgLock sync.Mutex +) + +func getLoadAvg() *load.AvgStat { + loadAvgLock.Lock() + defer loadAvgLock.Unlock() + + // Return cache if still valid. + if time.Now().Before(loadAvgExpires) { + return loadAvg + } + + // Refresh. + var err error + loadAvg, err = load.Avg() + if err != nil { + log.Warningf("metrics: failed to get load avg: %s", err) + loadAvg = nil + } + loadAvgExpires = time.Now().Add(hostStatTTL) + + return loadAvg +} + +// LoadAvg1 returns the 1-minute average system load. +func LoadAvg1() (loadAvg float64, ok bool) { + if stat := getLoadAvg(); stat != nil { + return stat.Load1 / float64(runtime.NumCPU()), true + } + return 0, false +} + +// LoadAvg5 returns the 5-minute average system load. +func LoadAvg5() (loadAvg float64, ok bool) { + if stat := getLoadAvg(); stat != nil { + return stat.Load5 / float64(runtime.NumCPU()), true + } + return 0, false +} + +// LoadAvg15 returns the 15-minute average system load. +func LoadAvg15() (loadAvg float64, ok bool) { + if stat := getLoadAvg(); stat != nil { + return stat.Load15 / float64(runtime.NumCPU()), true + } + return 0, false +} + +var ( + memStat *mem.VirtualMemoryStat + memStatExpires time.Time + memStatLock sync.Mutex +) + +func getMemStat() *mem.VirtualMemoryStat { + memStatLock.Lock() + defer memStatLock.Unlock() + + // Return cache if still valid. + if time.Now().Before(memStatExpires) { + return memStat + } + + // Refresh. + var err error + memStat, err = mem.VirtualMemory() + if err != nil { + log.Warningf("metrics: failed to get load avg: %s", err) + memStat = nil + } + memStatExpires = time.Now().Add(hostStatTTL) + + return memStat +} + +// MemTotal returns the total system memory. +func MemTotal() (total uint64, ok bool) { + if stat := getMemStat(); stat != nil { + return stat.Total, true + } + return 0, false +} + +// MemUsed returns the used system memory. +func MemUsed() (used uint64, ok bool) { + if stat := getMemStat(); stat != nil { + return stat.Used, true + } + return 0, false +} + +// MemAvailable returns the available system memory. +func MemAvailable() (available uint64, ok bool) { + if stat := getMemStat(); stat != nil { + return stat.Available, true + } + return 0, false +} + +// MemUsedPercent returns the percent of used system memory. 
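A short usage note for the host stat getters above and just below: each returns an ok flag that is false when the underlying gopsutil call failed, and results are cached for one second (hostStatTTL). An illustrative sketch:

```go
package example

import (
	"fmt"

	"github.com/safing/portmaster/base/metrics"
)

func printHostStats() {
	if load1, ok := metrics.LoadAvg1(); ok {
		fmt.Printf("load (1m, per CPU): %.2f\n", load1)
	}
	if usedPercent, ok := metrics.MemUsedPercent(); ok {
		fmt.Printf("memory used: %.1f%%\n", usedPercent)
	}
	if free, ok := metrics.DiskFree(); ok {
		fmt.Printf("disk free: %d bytes\n", free)
	}
}
```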
+func MemUsedPercent() (usedPercent float64, ok bool) { + if stat := getMemStat(); stat != nil { + return stat.UsedPercent, true + } + return 0, false +} + +var ( + diskStat *disk.UsageStat + diskStatExpires time.Time + diskStatLock sync.Mutex +) + +func getDiskStat() *disk.UsageStat { + diskStatLock.Lock() + defer diskStatLock.Unlock() + + // Return cache if still valid. + if time.Now().Before(diskStatExpires) { + return diskStat + } + + // Check if we have a data root. + dataRoot := dataroot.Root() + if dataRoot == nil { + log.Warning("metrics: cannot get disk stats without data root") + diskStat = nil + diskStatExpires = time.Now().Add(hostStatTTL) + return diskStat + } + + // Refresh. + var err error + diskStat, err = disk.Usage(dataRoot.Path) + if err != nil { + log.Warningf("metrics: failed to get load avg: %s", err) + diskStat = nil + } + diskStatExpires = time.Now().Add(hostStatTTL) + + return diskStat +} + +// DiskTotal returns the total disk space (from the program's data root). +func DiskTotal() (total uint64, ok bool) { + if stat := getDiskStat(); stat != nil { + return stat.Total, true + } + return 0, false +} + +// DiskUsed returns the used disk space (from the program's data root). +func DiskUsed() (used uint64, ok bool) { + if stat := getDiskStat(); stat != nil { + return stat.Used, true + } + return 0, false +} + +// DiskFree returns the available disk space (from the program's data root). +func DiskFree() (free uint64, ok bool) { + if stat := getDiskStat(); stat != nil { + return stat.Free, true + } + return 0, false +} + +// DiskUsedPercent returns the percent of used disk space (from the program's data root). +func DiskUsedPercent() (usedPercent float64, ok bool) { + if stat := getDiskStat(); stat != nil { + return stat.UsedPercent, true + } + return 0, false +} diff --git a/base/metrics/metrics_info.go b/base/metrics/metrics_info.go new file mode 100644 index 000000000..8ce77f4dc --- /dev/null +++ b/base/metrics/metrics_info.go @@ -0,0 +1,45 @@ +package metrics + +import ( + "runtime" + "strings" + "sync/atomic" + + "github.com/safing/portmaster/base/info" +) + +var reportedStart atomic.Bool + +func registerInfoMetric() error { + meta := info.GetInfo() + _, err := NewGauge( + "info", + map[string]string{ + "version": checkUnknown(meta.Version), + "commit": checkUnknown(meta.Commit), + "build_date": checkUnknown(meta.BuildTime), + "build_source": checkUnknown(meta.Source), + "go_os": runtime.GOOS, + "go_arch": runtime.GOARCH, + "go_version": runtime.Version(), + "go_compiler": runtime.Compiler, + "comment": commentOption(), + }, + func() float64 { + // Report as 0 the first time in order to detect (re)starts. 
+ if reportedStart.CompareAndSwap(false, true) { + return 0 + } + return 1 + }, + nil, + ) + return err +} + +func checkUnknown(s string) string { + if strings.Contains(s, "unknown") { + return "unknown" + } + return s +} diff --git a/base/metrics/metrics_logs.go b/base/metrics/metrics_logs.go new file mode 100644 index 000000000..6e57a0b9d --- /dev/null +++ b/base/metrics/metrics_logs.go @@ -0,0 +1,49 @@ +package metrics + +import ( + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/log" +) + +func registerLogMetrics() (err error) { + _, err = NewFetchingCounter( + "logs/warning/total", + nil, + log.TotalWarningLogLines, + &Options{ + Name: "Total Warning Log Lines", + Permission: api.PermitUser, + }, + ) + if err != nil { + return err + } + + _, err = NewFetchingCounter( + "logs/error/total", + nil, + log.TotalErrorLogLines, + &Options{ + Name: "Total Error Log Lines", + Permission: api.PermitUser, + }, + ) + if err != nil { + return err + } + + _, err = NewFetchingCounter( + "logs/critical/total", + nil, + log.TotalCriticalLogLines, + &Options{ + Name: "Total Critical Log Lines", + Permission: api.PermitUser, + }, + ) + if err != nil { + return err + } + + return nil +} diff --git a/base/metrics/metrics_runtime.go b/base/metrics/metrics_runtime.go new file mode 100644 index 000000000..88bff9d53 --- /dev/null +++ b/base/metrics/metrics_runtime.go @@ -0,0 +1,98 @@ +package metrics + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" + + vm "github.com/VictoriaMetrics/metrics" + + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" +) + +func registerRuntimeMetric() error { + runtimeBase, err := newMetricBase("_runtime", nil, Options{ + Name: "Golang Runtime", + Permission: api.PermitAdmin, + ExpertiseLevel: config.ExpertiseLevelDeveloper, + }) + if err != nil { + return err + } + + return register(&runtimeMetrics{ + metricBase: runtimeBase, + }) +} + +type runtimeMetrics struct { + *metricBase +} + +func (r *runtimeMetrics) WritePrometheus(w io.Writer) { + // If there nothing to change, just write directly to w. + if metricNamespace == "" && len(globalLabels) == 0 { + vm.WriteProcessMetrics(w) + return + } + + // Write metrics to buffer. + buf := new(bytes.Buffer) + vm.WriteProcessMetrics(buf) + + // Add namespace and label per line. + scanner := bufio.NewScanner(buf) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + line := scanner.Text() + + // Add namespace, if set. + if metricNamespace != "" { + line = metricNamespace + "_" + line + } + + // Add global labels, if set. + if len(globalLabels) > 0 { + // Find where to insert. + mergeWithExisting := true + insertAt := strings.Index(line, "{") + 1 + if insertAt <= 0 { + mergeWithExisting = false + insertAt = strings.Index(line, " ") + if insertAt < 0 { + continue + } + } + + // Write new line directly to w. + fmt.Fprint(w, line[:insertAt]) + if !mergeWithExisting { + fmt.Fprint(w, "{") + } + labelsAdded := 0 + for labelKey, labelValue := range globalLabels { + fmt.Fprintf(w, "%s=%q", labelKey, labelValue) + // Add separator if not last label. + labelsAdded++ + if labelsAdded < len(globalLabels) { + fmt.Fprint(w, ", ") + } + } + if mergeWithExisting { + fmt.Fprint(w, ", ") + } else { + fmt.Fprint(w, "}") + } + fmt.Fprintln(w, line[insertAt:]) + } + } + + // Check if there was an error in the scanner. 
+ if scanner.Err() != nil { + log.Warningf("metrics: failed to scan go process metrics: %s", scanner.Err()) + } +} diff --git a/base/metrics/module.go b/base/metrics/module.go new file mode 100644 index 000000000..96ed9563a --- /dev/null +++ b/base/metrics/module.go @@ -0,0 +1,171 @@ +package metrics + +import ( + "errors" + "fmt" + "sort" + "sync" + + "github.com/safing/portmaster/base/modules" +) + +var ( + module *modules.Module + + registry []Metric + registryLock sync.RWMutex + + firstMetricRegistered bool + metricNamespace string + globalLabels = make(map[string]string) + + // ErrAlreadyStarted is returned when an operation is only valid before the + // first metric is registered, and is called after. + ErrAlreadyStarted = errors.New("can only be changed before first metric is registered") + + // ErrAlreadyRegistered is returned when a metric with the same ID is + // registered again. + ErrAlreadyRegistered = errors.New("metric already registered") + + // ErrAlreadySet is returned when a value is already set and cannot be changed. + ErrAlreadySet = errors.New("already set") + + // ErrInvalidOptions is returned when invalid options where provided. + ErrInvalidOptions = errors.New("invalid options") +) + +func init() { + module = modules.Register("metrics", prep, start, stop, "config", "database", "api") +} + +func prep() error { + return prepConfig() +} + +func start() error { + // Add metric instance name as global variable if set. + if instanceOption() != "" { + if err := AddGlobalLabel("instance", instanceOption()); err != nil { + return err + } + } + + if err := registerInfoMetric(); err != nil { + return err + } + + if err := registerRuntimeMetric(); err != nil { + return err + } + + if err := registerHostMetrics(); err != nil { + return err + } + + if err := registerLogMetrics(); err != nil { + return err + } + + if err := registerAPI(); err != nil { + return err + } + + if pushOption() != "" { + module.StartServiceWorker("metric pusher", 0, metricsWriter) + } + + return nil +} + +func stop() error { + // Wait until the metrics pusher is done, as it may have started reporting + // and may report a higher number than we store to disk. For persistent + // metrics it can then happen that the first report is lower than the + // previous report, making prometheus think that all that happened since the + // last report, due to the automatic restart detection. + + // The registry is read locked when writing metrics. + // Write lock the registry to make sure all writes are finished. + registryLock.Lock() + registryLock.Unlock() //nolint:staticcheck + + storePersistentMetrics() + + return nil +} + +func register(m Metric) error { + registryLock.Lock() + defer registryLock.Unlock() + + // Check if metric ID is already registered. + for _, registeredMetric := range registry { + if m.LabeledID() == registeredMetric.LabeledID() { + return ErrAlreadyRegistered + } + if m.Opts().InternalID != "" && + m.Opts().InternalID == registeredMetric.Opts().InternalID { + return fmt.Errorf("%w with this internal ID", ErrAlreadyRegistered) + } + } + + // Add new metric to registry and sort it. + registry = append(registry, m) + sort.Sort(byLabeledID(registry)) + + // Set flag that first metric is now registered. + firstMetricRegistered = true + + if module.Status() < modules.StatusStarting { + return fmt.Errorf("registering metric %q too early", m.ID()) + } + + return nil +} + +// SetNamespace sets the namespace for all metrics. It is prefixed to all +// metric IDs. 
+// It must be set before any metric is registered. +// Does not affect golang runtime metrics. +func SetNamespace(namespace string) error { + // Lock registry and check if a first metric is already registered. + registryLock.Lock() + defer registryLock.Unlock() + if firstMetricRegistered { + return ErrAlreadyStarted + } + + // Check if the namespace is already set. + if metricNamespace != "" { + return ErrAlreadySet + } + + metricNamespace = namespace + return nil +} + +// AddGlobalLabel adds a global label to all metrics. +// Global labels must be added before any metric is registered. +// Does not affect golang runtime metrics. +func AddGlobalLabel(name, value string) error { + // Lock registry and check if a first metric is already registered. + registryLock.Lock() + defer registryLock.Unlock() + if firstMetricRegistered { + return ErrAlreadyStarted + } + + // Check format. + if !prometheusFormat.MatchString(name) { + return fmt.Errorf("metric label name %q must match %s", name, PrometheusFormatRequirement) + } + + globalLabels[name] = value + return nil +} + +type byLabeledID []Metric + +func (r byLabeledID) Len() int { return len(r) } +func (r byLabeledID) Less(i, j int) bool { return r[i].LabeledID() < r[j].LabeledID() } +func (r byLabeledID) Swap(i, j int) { r[i], r[j] = r[j], r[i] } diff --git a/base/metrics/persistence.go b/base/metrics/persistence.go new file mode 100644 index 000000000..2dba585ff --- /dev/null +++ b/base/metrics/persistence.go @@ -0,0 +1,153 @@ +package metrics + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/tevino/abool" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" +) + +var ( + storage *metricsStorage + storageKey string + storageInit = abool.New() + storageLoaded = abool.New() + + db = database.NewInterface(&database.Options{ + Local: true, + Internal: true, + }) + + // ErrAlreadyInitialized is returned when trying to initialize an option + // more than once or if the time window for initializing is over. + ErrAlreadyInitialized = errors.New("already initialized") +) + +type metricsStorage struct { + sync.Mutex + record.Base + + Start time.Time + Counters map[string]uint64 +} + +// EnableMetricPersistence enables metric persistence for metrics that opted +// for it. They given key is the database key where the metric data will be +// persisted. +// This call also directly loads the stored data from the database. +// The returned error is only about loading the metrics, not about enabling +// persistence. +// May only be called once. +func EnableMetricPersistence(key string) error { + // Check if already initialized. + if !storageInit.SetToIf(false, true) { + return ErrAlreadyInitialized + } + + // Set storage key. + storageKey = key + + // Load metrics from storage. + var err error + storage, err = getMetricsStorage(storageKey) + switch { + case err == nil: + // Continue. + case errors.Is(err, database.ErrNotFound): + return nil + default: + return err + } + storageLoaded.Set() + + // Load saved state for all counter metrics. + registryLock.RLock() + defer registryLock.RUnlock() + + for _, m := range registry { + counter, ok := m.(*Counter) + if ok { + counter.loadState() + } + } + + return nil +} + +func (c *Counter) loadState() { + // Check if we can and should load the state. 
+ if !storageLoaded.IsSet() || !c.Opts().Persist { + return + } + + c.Set(storage.Counters[c.LabeledID()]) +} + +func storePersistentMetrics() { + // Check if persistence is enabled. + if !storageInit.IsSet() || storageKey == "" { + return + } + + // Create new storage. + newStorage := &metricsStorage{ + // TODO: This timestamp should be taken from previous save, if possible. + Start: time.Now(), + Counters: make(map[string]uint64), + } + newStorage.SetKey(storageKey) + // Copy values from previous version. + if storageLoaded.IsSet() { + newStorage.Start = storage.Start + } + + registryLock.RLock() + defer registryLock.RUnlock() + + // Export all counter metrics. + for _, m := range registry { + if m.Opts().Persist { + counter, ok := m.(*Counter) + if ok { + newStorage.Counters[m.LabeledID()] = counter.Get() + } + } + } + + // Save to database. + err := db.Put(newStorage) + if err != nil { + log.Warningf("metrics: failed to save metrics storage to db: %s", err) + } +} + +func getMetricsStorage(key string) (*metricsStorage, error) { + r, err := db.Get(key) + if err != nil { + return nil, err + } + + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + newStorage := &metricsStorage{} + err = record.Unwrap(r, newStorage) + if err != nil { + return nil, err + } + return newStorage, nil + } + + // or adjust type + newStorage, ok := r.(*metricsStorage) + if !ok { + return nil, fmt.Errorf("record not of type *metricsStorage, but %T", r) + } + return newStorage, nil +} diff --git a/base/metrics/testdata/.gitignore b/base/metrics/testdata/.gitignore new file mode 100644 index 000000000..6320cd248 --- /dev/null +++ b/base/metrics/testdata/.gitignore @@ -0,0 +1 @@ +data \ No newline at end of file diff --git a/base/metrics/testdata/README.md b/base/metrics/testdata/README.md new file mode 100644 index 000000000..e3239f664 --- /dev/null +++ b/base/metrics/testdata/README.md @@ -0,0 +1,4 @@ +# Testing metrics + +You can spin up a test setup for pushing and viewing metrics with `docker-compose up`. +Then use the flag `--push-metrics http://127.0.0.1:8428/api/v1/import/prometheus` to push metrics. 
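+
+A minimal end-to-end check, using only what is defined in this directory, could look like the following. The Portmaster binary name below is just a placeholder for whatever build you are testing; the ports and the preconfigured Grafana setup come from the docker-compose and provisioning files in this directory.
+
+    docker-compose up -d
+    ./<your-portmaster-binary> --push-metrics http://127.0.0.1:8428/api/v1/import/prometheus
+
+Grafana is then reachable at http://127.0.0.1:3000 (anonymous admin access, VictoriaMetrics already set up as the default datasource), and the raw VictoriaMetrics API at http://127.0.0.1:8428.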
diff --git a/base/metrics/testdata/docker-compose.yml b/base/metrics/testdata/docker-compose.yml new file mode 100644 index 000000000..e7d7bdc9b --- /dev/null +++ b/base/metrics/testdata/docker-compose.yml @@ -0,0 +1,36 @@ +version: '3.8' + +networks: + pm-metrics-test-net: + +services: + + victoriametrics: + container_name: pm-metrics-test-victoriametrics + image: victoriametrics/victoria-metrics + command: + - '--storageDataPath=/storage' + ports: + - 8428:8428 + volumes: + - ./data/victoriametrics:/storage + networks: + - pm-metrics-test-net + restart: always + + grafana: + container_name: pm-metrics-test-grafana + image: grafana/grafana + command: + - '--config=/etc/grafana/provisioning/config.ini' + depends_on: + - "victoriametrics" + ports: + - 3000:3000 + volumes: + - ./data/grafana:/var/lib/grafana + - ./grafana:/etc/grafana/provisioning + - ./dashboards:/dashboards + networks: + - pm-metrics-test-net + restart: always diff --git a/base/metrics/testdata/grafana/config.ini b/base/metrics/testdata/grafana/config.ini new file mode 100644 index 000000000..341a31a97 --- /dev/null +++ b/base/metrics/testdata/grafana/config.ini @@ -0,0 +1,10 @@ +[auth] +disable_login_form = true +disable_signout_menu = true + +[auth.basic] +enabled = false + +[auth.anonymous] +enabled = true +org_role = Admin diff --git a/base/metrics/testdata/grafana/dashboards/portmaster.yml b/base/metrics/testdata/grafana/dashboards/portmaster.yml new file mode 100644 index 000000000..42813eba5 --- /dev/null +++ b/base/metrics/testdata/grafana/dashboards/portmaster.yml @@ -0,0 +1,11 @@ +apiVersion: 1 + +providers: + - name: 'Portmaster' + folder: 'Portmaster' + disableDeletion: true + updateIntervalSeconds: 10 + allowUiUpdates: true + options: + path: /dashboards + foldersFromFilesStructure: true diff --git a/base/metrics/testdata/grafana/datasources/datasource.yml b/base/metrics/testdata/grafana/datasources/datasource.yml new file mode 100644 index 000000000..a83316581 --- /dev/null +++ b/base/metrics/testdata/grafana/datasources/datasource.yml @@ -0,0 +1,8 @@ +apiVersion: 1 + +datasources: + - name: VictoriaMetrics + type: prometheus + access: proxy + url: http://pm-metrics-test-victoriametrics:8428 + isDefault: true diff --git a/base/notifications/cleaner.go b/base/notifications/cleaner.go new file mode 100644 index 000000000..62982616f --- /dev/null +++ b/base/notifications/cleaner.go @@ -0,0 +1,51 @@ +package notifications + +import ( + "context" + "time" +) + +func cleaner(ctx context.Context) error { //nolint:unparam // Conforms to worker interface + ticker := module.NewSleepyTicker(1*time.Second, 0) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.Wait(): + deleteExpiredNotifs() + } + } +} + +func deleteExpiredNotifs() { + // Get a copy of the notification map. + notsCopy := getNotsCopy() + + // Delete all expired notifications. 
+ for _, n := range notsCopy { + if n.isExpired() { + n.delete(true) + } + } +} + +func (n *Notification) isExpired() bool { + n.Lock() + defer n.Unlock() + + return n.Expires > 0 && n.Expires < time.Now().Unix() +} + +func getNotsCopy() []*Notification { + notsLock.RLock() + defer notsLock.RUnlock() + + notsCopy := make([]*Notification, 0, len(nots)) + for _, n := range nots { + notsCopy = append(notsCopy, n) + } + + return notsCopy +} diff --git a/base/notifications/config.go b/base/notifications/config.go new file mode 100644 index 000000000..53d5128db --- /dev/null +++ b/base/notifications/config.go @@ -0,0 +1,32 @@ +package notifications + +import ( + "github.com/safing/portmaster/base/config" +) + +// Configuration Keys. +var ( + CfgUseSystemNotificationsKey = "core/useSystemNotifications" + useSystemNotifications config.BoolOption +) + +func registerConfig() error { + if err := config.Register(&config.Option{ + Name: "Desktop Notifications", + Key: CfgUseSystemNotificationsKey, + Description: "In addition to showing notifications in the Portmaster App, also send them to the Desktop. This requires the Portmaster Notifier to be running.", + OptType: config.OptTypeBool, + ExpertiseLevel: config.ExpertiseLevelUser, + ReleaseLevel: config.ReleaseLevelStable, + DefaultValue: true, // TODO: turn off by default on unsupported systems + Annotations: config.Annotations{ + config.DisplayOrderAnnotation: -15, + config.CategoryAnnotation: "User Interface", + }, + }); err != nil { + return err + } + useSystemNotifications = config.Concurrent.GetAsBool(CfgUseSystemNotificationsKey, true) + + return nil +} diff --git a/base/notifications/database.go b/base/notifications/database.go new file mode 100644 index 000000000..6cbbd86bf --- /dev/null +++ b/base/notifications/database.go @@ -0,0 +1,239 @@ +package notifications + +import ( + "errors" + "fmt" + "strings" + "sync" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" + "github.com/safing/portmaster/base/log" +) + +var ( + nots = make(map[string]*Notification) + notsLock sync.RWMutex + + dbController *database.Controller +) + +// Storage interface errors. +var ( + ErrInvalidData = errors.New("invalid data, must be a notification object") + ErrInvalidPath = errors.New("invalid path") + ErrNoDelete = errors.New("notifications may not be deleted, they must be handled") +) + +// StorageInterface provices a storage.Interface to the configuration manager. +type StorageInterface struct { + storage.InjectBase +} + +func registerAsDatabase() error { + _, err := database.Register(&database.Database{ + Name: "notifications", + Description: "Notifications", + StorageType: "injected", + }) + if err != nil { + return err + } + + controller, err := database.InjectDatabase("notifications", &StorageInterface{}) + if err != nil { + return err + } + + dbController = controller + return nil +} + +// Get returns a database record. +func (s *StorageInterface) Get(key string) (record.Record, error) { + // Get EventID from key. + if !strings.HasPrefix(key, "all/") { + return nil, storage.ErrNotFound + } + key = strings.TrimPrefix(key, "all/") + + // Get notification from storage. 
+ n, ok := getNotification(key) + if !ok { + return nil, storage.ErrNotFound + } + + return n, nil +} + +func getNotification(eventID string) (n *Notification, ok bool) { + notsLock.RLock() + defer notsLock.RUnlock() + + n, ok = nots[eventID] + return +} + +// Query returns a an iterator for the supplied query. +func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + it := iterator.New() + go s.processQuery(q, it) + // TODO: check local and internal + + return it, nil +} + +func (s *StorageInterface) processQuery(q *query.Query, it *iterator.Iterator) { + // Get a copy of the notification map. + notsCopy := getNotsCopy() + + // send all notifications + for _, n := range notsCopy { + if inQuery(n, q) { + select { + case it.Next <- n: + case <-it.Done: + // make sure we don't leak this goroutine if the iterator get's cancelled + return + } + } + } + + it.Finish(nil) +} + +func inQuery(n *Notification, q *query.Query) bool { + n.lock.Lock() + defer n.lock.Unlock() + + switch { + case n.Meta().IsDeleted(): + return false + case !q.MatchesKey(n.DatabaseKey()): + return false + case !q.MatchesRecord(n): + return false + } + + return true +} + +// Put stores a record in the database. +func (s *StorageInterface) Put(r record.Record) (record.Record, error) { + // record is already locked! + key := r.DatabaseKey() + n, err := EnsureNotification(r) + if err != nil { + return nil, ErrInvalidData + } + + // transform key + if strings.HasPrefix(key, "all/") { + key = strings.TrimPrefix(key, "all/") + } else { + return nil, ErrInvalidPath + } + + return applyUpdate(n, key) +} + +func applyUpdate(n *Notification, key string) (*Notification, error) { + // separate goroutine in order to correctly lock notsLock + existing, ok := getNotification(key) + + // ignore if already deleted + if !ok || existing.Meta().IsDeleted() { + // this is a completely new notification + // we pass pushUpdate==false because the storage + // controller will push an update on put anyway. + n.save(false) + return n, nil + } + + // Save when we're finished, if needed. + save := false + defer func() { + if save { + existing.save(false) + } + }() + + existing.Lock() + defer existing.Unlock() + + if existing.State == Executed { + return existing, fmt.Errorf("action already executed") + } + + // check if the notification has been marked as + // "executed externally". + if n.State == Executed { + log.Tracef("notifications: action for %s executed externally", n.EventID) + existing.State = Executed + save = true + + // in case the action has been executed immediately by the + // sender we may need to update the SelectedActionID. + // Though, we guard the assignments with value check + // so partial updates that only change the + // State property do not overwrite existing values. + if n.SelectedActionID != "" { + existing.SelectedActionID = n.SelectedActionID + } + } + + if n.SelectedActionID != "" && existing.State == Active { + log.Tracef("notifications: selected action for %s: %s", n.EventID, n.SelectedActionID) + existing.selectAndExecuteAction(n.SelectedActionID) + save = true + } + + return existing, nil +} + +// Delete deletes a record from the database. +func (s *StorageInterface) Delete(key string) error { + // Get EventID from key. + if !strings.HasPrefix(key, "all/") { + return storage.ErrNotFound + } + key = strings.TrimPrefix(key, "all/") + + // Get notification from storage. 
+ n, ok := getNotification(key) + if !ok { + return storage.ErrNotFound + } + + n.delete(true) + return nil +} + +// ReadOnly returns whether the database is read only. +func (s *StorageInterface) ReadOnly() bool { + return false +} + +// EnsureNotification ensures that the given record is a Notification and returns it. +func EnsureNotification(r record.Record) (*Notification, error) { + // unwrap + if r.IsWrapped() { + // only allocate a new struct, if we need it + n := &Notification{} + err := record.Unwrap(r, n) + if err != nil { + return nil, err + } + return n, nil + } + + // or adjust type + n, ok := r.(*Notification) + if !ok { + return nil, fmt.Errorf("record not of type *Notification, but %T", r) + } + return n, nil +} diff --git a/base/notifications/doc.go b/base/notifications/doc.go new file mode 100644 index 000000000..be1838673 --- /dev/null +++ b/base/notifications/doc.go @@ -0,0 +1,26 @@ +/* +Package notifications provides a notification system. + +# Notification Lifecycle + +1. Create Notification with an ID and Message. +2. Set possible actions and save it. +3. When the user responds, the action is executed. + +Example + + // create notification + n := notifications.New("update-available", "A new update is available. Restart to upgrade.") + // set actions and save + n.AddAction("later", "Later").AddAction("restart", "Restart now!").Save() + + // wait for user action + selectedAction := <-n.Response() + switch selectedAction { + case "later": + log.Infof("user wants to upgrade later.") + case "restart": + log.Infof("user wants to restart now.") + } +*/ +package notifications diff --git a/base/notifications/module-mirror.go b/base/notifications/module-mirror.go new file mode 100644 index 000000000..96173be4d --- /dev/null +++ b/base/notifications/module-mirror.go @@ -0,0 +1,115 @@ +package notifications + +import ( + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" +) + +// AttachToModule attaches the notification to a module and changes to the +// notification will be reflected on the module failure status. +func (n *Notification) AttachToModule(m *modules.Module) { + if m == nil { + log.Warningf("notifications: invalid usage: cannot attach %s to nil module", n.EventID) + return + } + + n.lock.Lock() + defer n.lock.Unlock() + + if n.State != Active { + log.Warningf("notifications: cannot attach module to inactive notification %s", n.EventID) + return + } + if n.belongsTo != nil { + log.Warningf("notifications: cannot override attached module for notification %s", n.EventID) + return + } + + // Attach module. + n.belongsTo = m + + // Set module failure status. + switch n.Type { //nolint:exhaustive + case Info: + m.Hint(n.EventID, n.Title, n.Message) + case Warning: + m.Warning(n.EventID, n.Title, n.Message) + case Error: + m.Error(n.EventID, n.Title, n.Message) + default: + log.Warningf("notifications: incompatible type for attaching to module in notification %s", n.EventID) + m.Error(n.EventID, n.Title, n.Message+" [incompatible notification type]") + } +} + +// resolveModuleFailure removes the notification from the module failure status. +func (n *Notification) resolveModuleFailure() { + if n.belongsTo != nil { + // Resolve failure in attached module. + n.belongsTo.Resolve(n.EventID) + + // Reset attachment in order to mitigate duplicate failure resolving. + // Re-attachment is prevented by the state check when attaching. 
+ n.belongsTo = nil + } +} + +func init() { + modules.SetFailureUpdateNotifyFunc(mirrorModuleStatus) +} + +func mirrorModuleStatus(moduleFailure uint8, id, title, msg string) { + // Ignore "resolve all" requests. + if id == "" { + return + } + + // Get notification from storage. + n, ok := getNotification(id) + if ok { + // The notification already exists. + + // Check if we should delete it. + if moduleFailure == modules.FailureNone && !n.Meta().IsDeleted() { + + // Remove belongsTo, as the deletion was already triggered by the module itself. + n.Lock() + n.belongsTo = nil + n.Unlock() + + n.Delete() + } + + return + } + + // A notification for the given ID does not yet exists, create it. + n = &Notification{ + EventID: id, + Title: title, + Message: msg, + AvailableActions: []*Action{ + { + Text: "Get Help", + Type: ActionTypeOpenURL, + Payload: "https://safing.io/support/", + }, + }, + } + + switch moduleFailure { + case modules.FailureNone: + return + case modules.FailureHint: + n.Type = Info + n.AvailableActions = nil + case modules.FailureWarning: + n.Type = Warning + n.ShowOnSystem = true + case modules.FailureError: + n.Type = Error + n.ShowOnSystem = true + } + + Notify(n) +} diff --git a/base/notifications/module.go b/base/notifications/module.go new file mode 100644 index 000000000..839c522cd --- /dev/null +++ b/base/notifications/module.go @@ -0,0 +1,66 @@ +package notifications + +import ( + "fmt" + "time" + + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/modules" +) + +var module *modules.Module + +func init() { + module = modules.Register("notifications", prep, start, nil, "database", "config", "base") +} + +func prep() error { + return registerConfig() +} + +func start() error { + err := registerAsDatabase() + if err != nil { + return err + } + + showConfigLoadingErrors() + + go module.StartServiceWorker("cleaner", 1*time.Second, cleaner) + return nil +} + +func showConfigLoadingErrors() { + validationErrors := config.GetLoadedConfigValidationErrors() + if len(validationErrors) == 0 { + return + } + + // Trigger a module error for more awareness. + module.Error( + "config:validation-errors-on-load", + "Invalid Settings", + "Some current settings are invalid. Please update them and restart the Portmaster.", + ) + + // Send one notification per invalid setting. + for _, validationError := range config.GetLoadedConfigValidationErrors() { + NotifyError( + fmt.Sprintf("config:validation-error:%s", validationError.Option.Key), + fmt.Sprintf("Invalid Setting for %s", validationError.Option.Name), + fmt.Sprintf(`Your current setting for %s is invalid: %s + +Please update the setting and restart the Portmaster, until then the default value is used.`, + validationError.Option.Name, + validationError.Err.Error(), + ), + Action{ + Text: "Change", + Type: ActionTypeOpenSetting, + Payload: &ActionTypeOpenSettingPayload{ + Key: validationError.Option.Key, + }, + }, + ) + } +} diff --git a/base/notifications/notification.go b/base/notifications/notification.go new file mode 100644 index 000000000..c088f78d7 --- /dev/null +++ b/base/notifications/notification.go @@ -0,0 +1,523 @@ +package notifications + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/utils" +) + +// Type describes the type of a notification. +type Type uint8 + +// Notification types. 
+const (
+	Info    Type = 0
+	Warning Type = 1
+	Prompt  Type = 2
+	Error   Type = 3
+)
+
+// State describes the state of a notification.
+type State string
+
+// NotificationActionFn defines the function signature for notification action
+// functions.
+type NotificationActionFn func(context.Context, *Notification) error
+
+// Possible notification states.
+// State transitions can only happen from top to bottom.
+const (
+	// Active describes a notification that is active, not expired and,
+	// if actions are available, still waiting for the user to select an
+	// action.
+	Active State = "active"
+	// Responded describes a notification where the user has already
+	// selected which action to take but that action is still to be
+	// performed.
+	Responded State = "responded"
+	// Executed describes a notification where the user has selected
+	// an action and that action has been performed.
+	Executed State = "executed"
+)
+
+// Notification represents a notification that is to be delivered to the user.
+type Notification struct { //nolint:maligned
+	record.Base
+	// EventID is used to identify a specific notification. It consists of
+	// the module name and a per-module unique event id.
+	// The following format is recommended:
+	//   <module name>:<event id>
+	EventID string
+	// GUID is a unique identifier for each notification instance. That is,
+	// two notifications with the same EventID must still have unique GUIDs.
+	// The GUID is mainly used for system (Windows) integration and is
+	// automatically populated by the notification package. Average users
+	// don't need to care about this field.
+	GUID string
+	// Type is the notification type. It can be one of Info, Warning, Prompt or Error.
+	Type Type
+	// Title is an optional and very short title for the message that gives a
+	// hint about what the notification is about.
+	Title string
+	// Category is an optional category for the notification that allows for
+	// tagging and grouping notifications by category.
+	Category string
+	// Message is the default message shown to the user if no localized version
+	// of the notification is available. Note that the message should already
+	// have any parameterized values replaced.
+	Message string
+	// ShowOnSystem specifies if the notification should also be shown on the
+	// operating system. Notifications shown on the operating system level are
+	// more focus-intrusive and should only be used for important notifications.
+	// If the configuration option "Desktop Notifications" is switched off, this
+	// will be forced to false on the first save.
+	ShowOnSystem bool
+	// EventData contains an additional payload for the notification. This payload
+	// may contain contextual data and may be used by a localization framework
+	// to populate the notification message template.
+	// If EventData implements sync.Locker it will be locked and unlocked together with the
+	// notification. Otherwise, EventData is expected to be immutable once the
+	// notification has been saved and handed over to the notification or database package.
+	EventData interface{}
+	// Expires holds the unix epoch timestamp at which the notification expires
+	// and can be cleaned up.
+	// Users can safely ignore expired notifications and should handle expiry the
+	// same as deletion.
+	Expires int64
+	// State describes the current state of a notification. See State for
+	// a list of available values and their meaning.
+	State State
+	// AvailableActions defines a list of actions that a user can choose from.
+	AvailableActions []*Action
+	// SelectedActionID is updated to match the ID of one of the AvailableActions
+	// based on the user selection.
+	SelectedActionID string
+
+	// belongsTo holds the module this notification belongs to. The notification
+	// lifecycle will be mirrored to the module's failure status.
+	belongsTo *modules.Module
+
+	lock           sync.Mutex
+	actionFunction NotificationActionFn // call function to process action
+	actionTrigger  chan string          // and/or send to a channel
+	expiredTrigger chan struct{}        // closed on expire
+}
+
+// Action describes an action that can be taken for a notification.
+type Action struct {
+	// ID specifies a unique ID for the action. If an action is selected, the ID
+	// is written to SelectedActionID and the notification is saved.
+	// If the action type is not ActionTypeNone, the ID may be empty, signifying
+	// that this action is merely additional and selecting it does not dismiss the
+	// notification.
+	ID string
+	// Text on the button.
+	Text string
+	// Type specifies the action type. Implementing interfaces should only
+	// display action types they can handle.
+	Type ActionType
+	// Payload holds additional data for special action types.
+	Payload interface{}
+}
+
+// ActionType defines a specific type of action.
+type ActionType string
+
+// Action Types.
+const (
+	ActionTypeNone        = ""             // Report selected ID back to backend.
+	ActionTypeOpenURL     = "open-url"     // Open external URL
+	ActionTypeOpenPage    = "open-page"    // Payload: Page ID
+	ActionTypeOpenSetting = "open-setting" // Payload: See struct definition below.
+	ActionTypeOpenProfile = "open-profile" // Payload: Scoped Profile ID
+	ActionTypeInjectEvent = "inject-event" // Payload: Event ID
+	ActionTypeWebhook     = "call-webhook" // Payload: See struct definition below.
+)
+
+// ActionTypeOpenSettingPayload defines the payload for the OpenSetting Action Type.
+type ActionTypeOpenSettingPayload struct {
+	// Key is the key of the setting.
+	Key string
+	// Profile is the scoped ID of the profile.
+	// Leaving this empty opens the global settings.
+	Profile string
+}
+
+// ActionTypeWebhookPayload defines the payload for the Webhook Action Type.
+type ActionTypeWebhookPayload struct {
+	// HTTP Method to use. Defaults to "GET", or "POST" if a Payload is supplied.
+	Method string
+	// URL to call.
+	// If the URL is relative, prepend the current API endpoint base path.
+	// If the URL is absolute, send request to the Portmaster.
+	URL string
+	// Payload holds arbitrary payload data.
+	Payload interface{}
+	// ResultAction defines what should be done with successfully returned data.
+	// Must be one of:
+	// - `ignore`: do nothing (default)
+	// - `display`: the result is a human-readable message, display it in a success message.
+	ResultAction string
+}
+
+// Get returns the notification identified by the given id or nil if it doesn't exist.
+func Get(id string) *Notification {
+	notsLock.RLock()
+	defer notsLock.RUnlock()
+	n, ok := nots[id]
+	if ok {
+		return n
+	}
+	return nil
+}
+
+// Delete deletes the notification with the given id.
+func Delete(id string) {
+	// Delete notification in defer to enable deferred unlocking.
+	var n *Notification
+	var ok bool
+	defer func() {
+		if ok {
+			n.Delete()
+		}
+	}()
+
+	notsLock.Lock()
+	defer notsLock.Unlock()
+	n, ok = nots[id]
+}
+
+// NotifyInfo is a helper method for quickly showing an info notification.
+// The notification will be activated immediately.
+// If the provided id is empty, an id will be derived from msg.
+// ShowOnSystem is disabled. +// If no actions are defined, a default "OK" (ID:"ack") action will be added. +func NotifyInfo(id, title, msg string, actions ...Action) *Notification { + return notify(Info, id, title, msg, false, actions...) +} + +// NotifyWarn is a helper method for quickly showing a warning notification +// The notification will be activated immediately. +// If the provided id is empty, an id will derived from msg. +// ShowOnSystem is enabled. +// If no actions are defined, a default "OK" (ID:"ack") action will be added. +func NotifyWarn(id, title, msg string, actions ...Action) *Notification { + return notify(Warning, id, title, msg, true, actions...) +} + +// NotifyError is a helper method for quickly showing an error notification. +// The notification will be activated immediately. +// If the provided id is empty, an id will derived from msg. +// ShowOnSystem is enabled. +// If no actions are defined, a default "OK" (ID:"ack") action will be added. +func NotifyError(id, title, msg string, actions ...Action) *Notification { + return notify(Error, id, title, msg, true, actions...) +} + +// NotifyPrompt is a helper method for quickly showing a prompt notification. +// The notification will be activated immediately. +// If the provided id is empty, an id will derived from msg. +// ShowOnSystem is disabled. +// If no actions are defined, a default "OK" (ID:"ack") action will be added. +func NotifyPrompt(id, title, msg string, actions ...Action) *Notification { + return notify(Prompt, id, title, msg, false, actions...) +} + +func notify(nType Type, id, title, msg string, showOnSystem bool, actions ...Action) *Notification { + // Process actions. + var acts []*Action + if len(actions) == 0 { + // Create ack action if there are no defined actions. + acts = []*Action{ + { + ID: "ack", + Text: "OK", + }, + } + } else { + // Reference given actions for notification. + acts = make([]*Action, len(actions)) + for index := range actions { + a := actions[index] + acts[index] = &a + } + } + + return Notify(&Notification{ + EventID: id, + Type: nType, + Title: title, + Message: msg, + ShowOnSystem: showOnSystem, + AvailableActions: acts, + }) +} + +// Notify sends the given notification. +func Notify(n *Notification) *Notification { + // While this function is very similar to Save(), it is much nicer to use in + // order to just fire off one notification, as it does not require some more + // uncommon Go syntax. + + n.save(true) + return n +} + +// Save saves the notification. +func (n *Notification) Save() { + n.save(true) +} + +// save saves the notification to the internal storage. It locks the +// notification, so it must not be locked when save is called. +func (n *Notification) save(pushUpdate bool) { + var id string + + // Save notification after pre-save processing. + defer func() { + if id != "" { + // Lock and save to notification storage. + notsLock.Lock() + defer notsLock.Unlock() + nots[id] = n + } + }() + + // We do not access EventData here, so it is enough to just lock the + // notification itself. + n.lock.Lock() + defer n.lock.Unlock() + + // Check if required data is present. + if n.Title == "" && n.Message == "" { + log.Warning("notifications: ignoring notification without Title or Message") + return + } + + // Derive EventID from Message if not given. + if n.EventID == "" { + n.EventID = fmt.Sprintf( + "unknown:%s", + utils.DerivedInstanceUUID(n.Message).String(), + ) + } + + // Save ID for deletion + id = n.EventID + + // Generate random GUID if not set. 
+ if n.GUID == "" { + n.GUID = utils.RandomUUID(n.EventID).String() + } + + // Make sure we always have a notification state assigned. + if n.State == "" { + n.State = Active + } + + // Initialize on first save. + if !n.KeyIsSet() { + // Set database key. + n.SetKey(fmt.Sprintf("notifications:all/%s", n.EventID)) + + // Check if notifications should be shown on the system at all. + if !useSystemNotifications() { + n.ShowOnSystem = false + } + } + + // Update meta data. + n.UpdateMeta() + + // Push update via the database system if needed. + if pushUpdate { + log.Tracef("notifications: pushing update for %s to subscribers", n.Key()) + dbController.PushUpdate(n) + } +} + +// SetActionFunction sets a trigger function to be executed when the user reacted on the notification. +// The provided function will be started as its own goroutine and will have to lock everything it accesses, even the provided notification. +func (n *Notification) SetActionFunction(fn NotificationActionFn) *Notification { + n.lock.Lock() + defer n.lock.Unlock() + n.actionFunction = fn + return n +} + +// Response waits for the user to respond to the notification and returns the selected action. +func (n *Notification) Response() <-chan string { + n.lock.Lock() + defer n.lock.Unlock() + + if n.actionTrigger == nil { + n.actionTrigger = make(chan string) + } + + return n.actionTrigger +} + +// Update updates/resends a notification if it was not already responded to. +func (n *Notification) Update(expires int64) { + // Save when we're finished, if needed. + save := false + defer func() { + if save { + n.save(true) + } + }() + + n.lock.Lock() + defer n.lock.Unlock() + + // Don't update if notification isn't active. + if n.State != Active { + return + } + + // Don't update too quickly. + if n.Meta().Modified > time.Now().Add(-10*time.Second).Unix() { + return + } + + // Update expiry and save. + n.Expires = expires + save = true +} + +// Delete (prematurely) cancels and deletes a notification. +func (n *Notification) Delete() { + // Dismiss notification. + func() { + n.lock.Lock() + defer n.lock.Unlock() + + if n.actionTrigger != nil { + close(n.actionTrigger) + n.actionTrigger = nil + } + }() + + n.delete(true) +} + +// delete deletes the notification from the internal storage. It locks the +// notification, so it must not be locked when delete is called. +func (n *Notification) delete(pushUpdate bool) { + var id string + + // Delete notification after processing deletion. + defer func() { + // Lock and delete from notification storage. + notsLock.Lock() + defer notsLock.Unlock() + delete(nots, id) + }() + + // We do not access EventData here, so it is enough to just lock the + // notification itself. + n.lock.Lock() + defer n.lock.Unlock() + + // Save ID for deletion + id = n.EventID + + // Mark notification as deleted. + n.Meta().Delete() + + // Close expiry channel if available. + if n.expiredTrigger != nil { + close(n.expiredTrigger) + n.expiredTrigger = nil + } + + // Push update via the database system if needed. + if pushUpdate { + dbController.PushUpdate(n) + } + + n.resolveModuleFailure() +} + +// Expired notifies the caller when the notification has expired. +func (n *Notification) Expired() <-chan struct{} { + n.lock.Lock() + defer n.lock.Unlock() + + if n.expiredTrigger == nil { + n.expiredTrigger = make(chan struct{}) + } + + return n.expiredTrigger +} + +// selectAndExecuteAction sets the user response and executes/triggers the action, if possible. 
+func (n *Notification) selectAndExecuteAction(id string) { + if n.State != Active { + return + } + + n.State = Responded + n.SelectedActionID = id + + executed := false + if n.actionFunction != nil { + module.StartWorker("notification action execution", func(ctx context.Context) error { + return n.actionFunction(ctx, n) + }) + executed = true + } + + if n.actionTrigger != nil { + // satisfy all listeners (if they are listening) + // TODO(ppacher): if we miss to notify the waiter here (because + // nobody is listeing on actionTrigger) we wil likely + // never be able to execute the action again (simply because + // we won't try). May consider replacing the single actionTrigger + // channel with a per-listener (buffered) one so we just send + // the value and close the channel. + triggerAll: + for { + select { + case n.actionTrigger <- n.SelectedActionID: + executed = true + case <-time.After(100 * time.Millisecond): // mitigate race conditions + break triggerAll + } + } + } + + if executed { + n.State = Executed + n.resolveModuleFailure() + } +} + +// Lock locks the Notification. If EventData is set and +// implements sync.Locker it is locked as well. Users that +// want to replace the EventData on a notification must +// ensure to unlock the current value on their own. If the +// new EventData implements sync.Locker as well, it must +// be locked prior to unlocking the notification. +func (n *Notification) Lock() { + n.lock.Lock() + if locker, ok := n.EventData.(sync.Locker); ok { + locker.Lock() + } +} + +// Unlock unlocks the Notification and the EventData, if +// it implements sync.Locker. See Lock() for more information +// on how to replace and work with EventData. +func (n *Notification) Unlock() { + n.lock.Unlock() + if locker, ok := n.EventData.(sync.Locker); ok { + locker.Unlock() + } +} diff --git a/base/rng/doc.go b/base/rng/doc.go new file mode 100644 index 000000000..b7fed113e --- /dev/null +++ b/base/rng/doc.go @@ -0,0 +1,9 @@ +// Package rng provides a feedable CSPRNG. +// +// CSPRNG used is fortuna: github.com/seehuhn/fortuna +// By default the CSPRNG is fed by two sources: +// - It starts with a seed from `crypto/rand` and periodically reseeds from there +// - A really simple tickfeeder which extracts entropy from the internal go scheduler using goroutines and is meant to be used under load. +// +// The RNG can also be easily fed with additional sources. +package rng diff --git a/base/rng/entropy.go b/base/rng/entropy.go new file mode 100644 index 000000000..b81e2bfde --- /dev/null +++ b/base/rng/entropy.go @@ -0,0 +1,124 @@ +package rng + +import ( + "context" + "encoding/binary" + + "github.com/tevino/abool" + + "github.com/safing/portmaster/base/container" +) + +const ( + minFeedEntropy = 256 +) + +var rngFeeder = make(chan []byte) + +// The Feeder is used to feed entropy to the RNG. +type Feeder struct { + input chan *entropyData + entropy int64 + needsEntropy *abool.AtomicBool + buffer *container.Container +} + +type entropyData struct { + data []byte + entropy int +} + +// NewFeeder returns a new entropy Feeder. +func NewFeeder() *Feeder { + newFeeder := &Feeder{ + input: make(chan *entropyData), + needsEntropy: abool.NewBool(true), + buffer: container.New(), + } + module.StartServiceWorker("feeder", 0, newFeeder.run) + return newFeeder +} + +// NeedsEntropy returns whether the feeder is currently gathering entropy. 
+func (f *Feeder) NeedsEntropy() bool { + return f.needsEntropy.IsSet() +} + +// SupplyEntropy supplies entropy to the Feeder, it will block until the Feeder has read from it. +func (f *Feeder) SupplyEntropy(data []byte, entropy int) { + f.input <- &entropyData{ + data: data, + entropy: entropy, + } +} + +// SupplyEntropyIfNeeded supplies entropy to the Feeder, but will not block if no entropy is currently needed. +func (f *Feeder) SupplyEntropyIfNeeded(data []byte, entropy int) { + if f.needsEntropy.IsSet() { + return + } + + select { + case f.input <- &entropyData{ + data: data, + entropy: entropy, + }: + default: + } +} + +// SupplyEntropyAsInt supplies entropy to the Feeder, it will block until the Feeder has read from it. +func (f *Feeder) SupplyEntropyAsInt(n int64, entropy int) { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(n)) + f.SupplyEntropy(b, entropy) +} + +// SupplyEntropyAsIntIfNeeded supplies entropy to the Feeder, but will not block if no entropy is currently needed. +func (f *Feeder) SupplyEntropyAsIntIfNeeded(n int64, entropy int) { + if f.needsEntropy.IsSet() { // avoid allocating a slice if possible + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(n)) + f.SupplyEntropyIfNeeded(b, entropy) + } +} + +// CloseFeeder stops the feed processing - the responsible goroutine exits. The input channel is closed and the feeder may not be used anymore in any way. +func (f *Feeder) CloseFeeder() { + close(f.input) +} + +func (f *Feeder) run(ctx context.Context) error { + defer f.needsEntropy.UnSet() + + for { + // gather + f.needsEntropy.Set() + gather: + for { + select { + case newEntropy := <-f.input: + // check if feed has been closed + if newEntropy == nil { + return nil + } + // append to buffer + f.buffer.Append(newEntropy.data) + f.entropy += int64(newEntropy.entropy) + if f.entropy >= minFeedEntropy { + break gather + } + case <-ctx.Done(): + return nil + } + } + // feed + f.needsEntropy.UnSet() + select { + case rngFeeder <- f.buffer.CompileData(): + case <-ctx.Done(): + return nil + } + f.buffer = container.New() + } +} diff --git a/base/rng/entropy_test.go b/base/rng/entropy_test.go new file mode 100644 index 000000000..76ebb4c78 --- /dev/null +++ b/base/rng/entropy_test.go @@ -0,0 +1,73 @@ +package rng + +import ( + "testing" + "time" +) + +func TestFeeder(t *testing.T) { + t.Parallel() + + // wait for start / first round to complete + time.Sleep(1 * time.Millisecond) + + f := NewFeeder() + + // go through all functions + f.NeedsEntropy() + f.SupplyEntropy([]byte{0}, 0) + f.SupplyEntropyAsInt(0, 0) + f.SupplyEntropyIfNeeded([]byte{0}, 0) + f.SupplyEntropyAsIntIfNeeded(0, 0) + + // fill entropy + f.SupplyEntropyAsInt(0, 65535) + + // check blocking calls + + waitOne := make(chan struct{}) + go func() { + f.SupplyEntropy([]byte{0}, 0) + close(waitOne) + }() + select { + case <-waitOne: + t.Error("call does not block!") + case <-time.After(10 * time.Millisecond): + } + + waitTwo := make(chan struct{}) + go func() { + f.SupplyEntropyAsInt(0, 0) + close(waitTwo) + }() + select { + case <-waitTwo: + t.Error("call does not block!") + case <-time.After(10 * time.Millisecond): + } + + // check non-blocking calls + + waitThree := make(chan struct{}) + go func() { + f.SupplyEntropyIfNeeded([]byte{0}, 0) + close(waitThree) + }() + select { + case <-waitThree: + case <-time.After(10 * time.Millisecond): + t.Error("call blocks!") + } + + waitFour := make(chan struct{}) + go func() { + f.SupplyEntropyAsIntIfNeeded(0, 0) + close(waitFour) + }() + 
select { + case <-waitFour: + case <-time.After(10 * time.Millisecond): + t.Error("call blocks!") + } +} diff --git a/base/rng/fullfeed.go b/base/rng/fullfeed.go new file mode 100644 index 000000000..e055f3e1d --- /dev/null +++ b/base/rng/fullfeed.go @@ -0,0 +1,43 @@ +package rng + +import ( + "context" + "time" +) + +func getFullFeedDuration() time.Duration { + // full feed every 5x time of reseedAfterSeconds + secsUntilFullFeed := reseedAfterSeconds * 5 + + // full feed at most once every ten minutes + if secsUntilFullFeed < 600 { + secsUntilFullFeed = 600 + } + + return time.Duration(secsUntilFullFeed) * time.Second +} + +func fullFeeder(ctx context.Context) error { + fullFeedDuration := getFullFeedDuration() + + for { + select { + case <-time.After(fullFeedDuration): + + rngLock.Lock() + feedAll: + for { + select { + case data := <-rngFeeder: + rng.Reseed(data) + default: + break feedAll + } + } + rngLock.Unlock() + + case <-ctx.Done(): + return nil + } + } +} diff --git a/base/rng/fullfeed_test.go b/base/rng/fullfeed_test.go new file mode 100644 index 000000000..c3da8373f --- /dev/null +++ b/base/rng/fullfeed_test.go @@ -0,0 +1,15 @@ +package rng + +import ( + "testing" +) + +func TestFullFeeder(t *testing.T) { + t.Parallel() + + for i := 0; i < 10; i++ { + go func() { + rngFeeder <- []byte{0} + }() + } +} diff --git a/base/rng/get.go b/base/rng/get.go new file mode 100644 index 000000000..3ff97be86 --- /dev/null +++ b/base/rng/get.go @@ -0,0 +1,94 @@ +package rng + +import ( + "encoding/binary" + "errors" + "io" + "math" + "time" +) + +const ( + reseedAfterSeconds = 600 // ten minutes + reseedAfterBytes = 1048576 // one megabyte +) + +var ( + // Reader provides a global instance to read from the RNG. + Reader io.Reader + + rngBytesRead uint64 + rngLastFeed = time.Now() +) + +// reader provides an io.Reader interface. +type reader struct{} + +func init() { + Reader = reader{} +} + +func checkEntropy() (err error) { + if !rngReady { + return errors.New("RNG is not ready yet") + } + if rngBytesRead > reseedAfterBytes || + int(time.Since(rngLastFeed).Seconds()) > reseedAfterSeconds { + select { + case r := <-rngFeeder: + rng.Reseed(r) + rngBytesRead = 0 + rngLastFeed = time.Now() + case <-time.After(1 * time.Second): + return errors.New("failed to get new entropy") + } + } + return nil +} + +// Read reads random bytes into the supplied byte slice. +func Read(b []byte) (n int, err error) { + rngLock.Lock() + defer rngLock.Unlock() + + if err := checkEntropy(); err != nil { + return 0, err + } + + return copy(b, rng.PseudoRandomData(uint(len(b)))), nil +} + +// Read implements the io.Reader interface. +func (r reader) Read(b []byte) (n int, err error) { + return Read(b) +} + +// Bytes allocates a new byte slice of given length and fills it with random data. +func Bytes(n int) ([]byte, error) { + rngLock.Lock() + defer rngLock.Unlock() + + if err := checkEntropy(); err != nil { + return nil, err + } + + return rng.PseudoRandomData(uint(n)), nil +} + +// Number returns a random number from 0 to (incl.) max. 
+func Number(max uint64) (uint64, error) { + secureLimit := math.MaxUint64 - (math.MaxUint64 % max) + max++ + + for { + randomBytes, err := Bytes(8) + if err != nil { + return 0, err + } + + candidate := binary.LittleEndian.Uint64(randomBytes) + if candidate < secureLimit { + return candidate % max, nil + } + } +} diff --git a/base/rng/get_test.go b/base/rng/get_test.go new file mode 100644 index 000000000..b92003747 --- /dev/null +++ b/base/rng/get_test.go @@ -0,0 +1,41 @@ +package rng + +import ( + "testing" +) + +func TestNumberRandomness(t *testing.T) { + t.Parallel() + + // skip in automated tests + t.Logf("Integer number bias test deactivated, as it sometimes triggers.") + t.SkipNow() + + if testing.Short() { + t.Skip() + } + + var subjects uint64 = 10 + var testSize uint64 = 10000 + + results := make([]uint64, int(subjects)) + for i := 0; i < int(subjects*testSize); i++ { + n, err := Number(subjects - 1) + if err != nil { + t.Fatal(err) + return + } + results[int(n)]++ + } + + // catch big mistakes in the number function, eg. massive % bias + lowerMargin := testSize - testSize/50 + upperMargin := testSize + testSize/50 + for subject, result := range results { + if result < lowerMargin || result > upperMargin { + t.Errorf("subject %d is outside of margins: %d", subject, result) + } + } + + t.Fatal(results) +} diff --git a/base/rng/osfeeder.go b/base/rng/osfeeder.go new file mode 100644 index 000000000..36aa8e4d0 --- /dev/null +++ b/base/rng/osfeeder.go @@ -0,0 +1,35 @@ +package rng + +import ( + "context" + "crypto/rand" + "fmt" +) + +func osFeeder(ctx context.Context) error { + entropyBytes := minFeedEntropy / 8 + feeder := NewFeeder() + defer feeder.CloseFeeder() + + for { + // gather + osEntropy := make([]byte, entropyBytes) + n, err := rand.Read(osEntropy) + if err != nil { + return fmt.Errorf("could not read entropy from os: %w", err) + } + if n != entropyBytes { + return fmt.Errorf("could not read enough entropy from os: got only %d bytes instead of %d", n, entropyBytes) + } + + // feed + select { + case feeder.input <- &entropyData{ + data: osEntropy, + entropy: entropyBytes * 8, + }: + case <-ctx.Done(): + return nil + } + } +} diff --git a/base/rng/rng.go b/base/rng/rng.go new file mode 100644 index 000000000..fa9bf5ca3 --- /dev/null +++ b/base/rng/rng.go @@ -0,0 +1,81 @@ +package rng + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "fmt" + "sync" + + "github.com/aead/serpent" + "github.com/seehuhn/fortuna" + + "github.com/safing/portmaster/base/modules" +) + +var ( + rng *fortuna.Generator + rngLock sync.Mutex + rngReady = false + + rngCipher = "aes" + // Possible values: "aes", "serpent". 
+ + module *modules.Module +) + +func init() { + module = modules.Register("rng", nil, start, nil) +} + +func newCipher(key []byte) (cipher.Block, error) { + switch rngCipher { + case "aes": + return aes.NewCipher(key) + case "serpent": + return serpent.NewCipher(key) + default: + return nil, fmt.Errorf("unknown or unsupported cipher: %s", rngCipher) + } +} + +func start() error { + rngLock.Lock() + defer rngLock.Unlock() + + rng = fortuna.NewGenerator(newCipher) + if rng == nil { + return errors.New("failed to initialize rng") + } + + // add another (async) OS rng seed + module.StartWorker("initial rng feed", func(_ context.Context) error { + // get entropy from OS + osEntropy := make([]byte, minFeedEntropy/8) + _, err := rand.Read(osEntropy) + if err != nil { + return fmt.Errorf("could not read entropy from os: %w", err) + } + // feed + rngLock.Lock() + rng.Reseed(osEntropy) + rngLock.Unlock() + return nil + }) + + // mark as ready + rngReady = true + + // random source: OS + module.StartServiceWorker("os rng feeder", 0, osFeeder) + + // random source: goroutine ticks + module.StartServiceWorker("tick rng feeder", 0, tickFeeder) + + // full feeder + module.StartServiceWorker("full feeder", 0, fullFeeder) + + return nil +} diff --git a/base/rng/rng_test.go b/base/rng/rng_test.go new file mode 100644 index 000000000..be70e17df --- /dev/null +++ b/base/rng/rng_test.go @@ -0,0 +1,50 @@ +package rng + +import ( + "testing" +) + +func init() { + err := start() + if err != nil { + panic(err) + } +} + +func TestRNG(t *testing.T) { + t.Parallel() + + key := make([]byte, 16) + + rngCipher = "aes" + _, err := newCipher(key) + if err != nil { + t.Errorf("failed to create aes cipher: %s", err) + } + + rngCipher = "serpent" + _, err = newCipher(key) + if err != nil { + t.Errorf("failed to create serpent cipher: %s", err) + } + + b := make([]byte, 32) + _, err = Read(b) + if err != nil { + t.Errorf("Read failed: %s", err) + } + _, err = Reader.Read(b) + if err != nil { + t.Errorf("Read failed: %s", err) + } + + _, err = Bytes(32) + if err != nil { + t.Errorf("Bytes failed: %s", err) + } + + _, err = Number(100) + if err != nil { + t.Errorf("Number failed: %s", err) + } +} diff --git a/base/rng/test/.gitignore b/base/rng/test/.gitignore new file mode 100644 index 000000000..a13540cf9 --- /dev/null +++ b/base/rng/test/.gitignore @@ -0,0 +1,4 @@ +test +*.bin +*.out +*.txt diff --git a/base/rng/test/README.md b/base/rng/test/README.md new file mode 100644 index 000000000..3ea732bbc --- /dev/null +++ b/base/rng/test/README.md @@ -0,0 +1,279 @@ +# Entropy Testing + +In order verify that the random package actually generates random enough entropy/data, this test program holds the core functions that generate entropy as well as some noise makers to simulate a running program. + +Please also note that output from `tickFeeder` is never used directly but fed as entropy to the actual RNG - `fortuna`. + +With `tickFeeder`, to be sure that the delivered entropy is of high enough quality, only 1 bit of entropy is expected per generated byte - ie. we gather 8 times the amount we need. The following test below is run on the raw output. 
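+
+As a rough worked example, using the numbers from the code comments: with a
+minimum feed entropy of 256 bits and 1 bit credited per generated byte, the
+tick feeder has to collect 2048 ticks (256 bytes) per feed. At the minimum
+tick interval of 10ms this takes about 21 seconds; with the default 10 minute
+reseed interval the feeder aims to be ready within a tenth of that time, which
+works out to one tick roughly every 29ms.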
+ +To test the quality of entropy, first generate random data with the test program: + + go build + + ./test tickfeeder tickfeeder.out 1 # just the additional entropy feed + # OR + ./test fortuna fortuna.out 10 # the actual CSPRNG with feeders + +Then, run `dieharder`, a random number generator test tool: + + dieharder -a -f output.bin + +Below you can find two test outputs of `dieharder`. +Please note that around 5 tests of `dieharder` normally fail. This is expected and even desired. +Also, the rng currently reseeds (ie. adds entropy) after 1MB or 10 minutes. + +`dieharder` of two samples of 10MB of fortuna (with feeders) (`go version go1.14.2 linux/amd64` on 21.04.2020): + + #=============================================================================# + # dieharder version 3.31.1 Copyright 2003 Robert G. Brown # + #=============================================================================# + rng_name | filename |rands/second| + mt19937| fortuna.out| 1.00e+08 | + #=============================================================================# + test_name |ntup| tsamples |psamples| p-value |Assessment + #=============================================================================# + diehard_birthdays| 0| 100| 100|0.69048981| PASSED | 2nd sample: PASSED + diehard_operm5| 0| 1000000| 100|0.76010702| PASSED | 2nd sample: PASSED + diehard_rank_32x32| 0| 40000| 100|0.86291558| PASSED | 2nd sample: PASSED + diehard_rank_6x8| 0| 100000| 100|0.63715647| PASSED | 2nd sample: PASSED + diehard_bitstream| 0| 2097152| 100|0.25389670| PASSED | 2nd sample: PASSED + diehard_opso| 0| 2097152| 100|0.70928590| PASSED | 2nd sample: PASSED + diehard_oqso| 0| 2097152| 100|0.75643141| PASSED | 2nd sample: PASSED + diehard_dna| 0| 2097152| 100|0.57096286| PASSED | 2nd sample: PASSED + diehard_count_1s_str| 0| 256000| 100|0.39650366| PASSED | 2nd sample: PASSED + diehard_count_1s_byt| 0| 256000| 100|0.26040557| PASSED | 2nd sample: PASSED + diehard_parking_lot| 0| 12000| 100|0.92327672| PASSED | 2nd sample: PASSED + diehard_2dsphere| 2| 8000| 100|0.86507605| PASSED | 2nd sample: PASSED + diehard_3dsphere| 3| 4000| 100|0.70845388| PASSED | 2nd sample: PASSED + diehard_squeeze| 0| 100000| 100|0.99744782| WEAK | 2nd sample: PASSED + diehard_sums| 0| 100| 100|0.27275938| PASSED | 2nd sample: PASSED + diehard_runs| 0| 100000| 100|0.27299936| PASSED | 2nd sample: PASSED + diehard_runs| 0| 100000| 100|0.42043270| PASSED | 2nd sample: PASSED + diehard_craps| 0| 200000| 100|0.91674884| PASSED | 2nd sample: PASSED + diehard_craps| 0| 200000| 100|0.77856237| PASSED | 2nd sample: PASSED + marsaglia_tsang_gcd| 0| 10000000| 100|0.77922797| PASSED | 2nd sample: PASSED + marsaglia_tsang_gcd| 0| 10000000| 100|0.94589532| PASSED | 2nd sample: PASSED + sts_monobit| 1| 100000| 100|0.99484549| PASSED | 2nd sample: PASSED + sts_runs| 2| 100000| 100|0.70036713| PASSED | 2nd sample: PASSED + sts_serial| 1| 100000| 100|0.79544015| PASSED | 2nd sample: PASSED + sts_serial| 2| 100000| 100|0.91473958| PASSED | 2nd sample: PASSED + sts_serial| 3| 100000| 100|0.66528037| PASSED | 2nd sample: PASSED + sts_serial| 3| 100000| 100|0.84028312| PASSED | 2nd sample: PASSED + sts_serial| 4| 100000| 100|0.82253130| PASSED | 2nd sample: PASSED + sts_serial| 4| 100000| 100|0.90695315| PASSED | 2nd sample: PASSED + sts_serial| 5| 100000| 100|0.55160515| PASSED | 2nd sample: PASSED + sts_serial| 5| 100000| 100|0.05256789| PASSED | 2nd sample: PASSED + sts_serial| 6| 100000| 100|0.25857850| PASSED | 2nd sample: PASSED + sts_serial| 6| 100000| 
100|0.58661649| PASSED | 2nd sample: PASSED + sts_serial| 7| 100000| 100|0.46915559| PASSED | 2nd sample: PASSED + sts_serial| 7| 100000| 100|0.57273130| PASSED | 2nd sample: PASSED + sts_serial| 8| 100000| 100|0.99182961| PASSED | 2nd sample: PASSED + sts_serial| 8| 100000| 100|0.86913367| PASSED | 2nd sample: PASSED + sts_serial| 9| 100000| 100|0.19259756| PASSED | 2nd sample: PASSED + sts_serial| 9| 100000| 100|0.61225842| PASSED | 2nd sample: PASSED + sts_serial| 10| 100000| 100|0.40792308| PASSED | 2nd sample: PASSED + sts_serial| 10| 100000| 100|0.99930785| WEAK | 2nd sample: PASSED + sts_serial| 11| 100000| 100|0.07296973| PASSED | 2nd sample: PASSED + sts_serial| 11| 100000| 100|0.04906522| PASSED | 2nd sample: PASSED + sts_serial| 12| 100000| 100|0.66400927| PASSED | 2nd sample: PASSED + sts_serial| 12| 100000| 100|0.67947609| PASSED | 2nd sample: PASSED + sts_serial| 13| 100000| 100|0.20412325| PASSED | 2nd sample: PASSED + sts_serial| 13| 100000| 100|0.19781734| PASSED | 2nd sample: PASSED + sts_serial| 14| 100000| 100|0.08541533| PASSED | 2nd sample: PASSED + sts_serial| 14| 100000| 100|0.07438464| PASSED | 2nd sample: PASSED + sts_serial| 15| 100000| 100|0.04607276| PASSED | 2nd sample: PASSED + sts_serial| 15| 100000| 100|0.56460340| PASSED | 2nd sample: PASSED + sts_serial| 16| 100000| 100|0.40211405| PASSED | 2nd sample: PASSED + sts_serial| 16| 100000| 100|0.81369172| PASSED | 2nd sample: PASSED + rgb_bitdist| 1| 100000| 100|0.52317549| PASSED | 2nd sample: PASSED + rgb_bitdist| 2| 100000| 100|0.49819655| PASSED | 2nd sample: PASSED + rgb_bitdist| 3| 100000| 100|0.65830167| PASSED | 2nd sample: PASSED + rgb_bitdist| 4| 100000| 100|0.75278398| PASSED | 2nd sample: PASSED + rgb_bitdist| 5| 100000| 100|0.23537303| PASSED | 2nd sample: PASSED + rgb_bitdist| 6| 100000| 100|0.82461608| PASSED | 2nd sample: PASSED + rgb_bitdist| 7| 100000| 100|0.46944789| PASSED | 2nd sample: PASSED + rgb_bitdist| 8| 100000| 100|0.44371293| PASSED | 2nd sample: PASSED + rgb_bitdist| 9| 100000| 100|0.61647469| PASSED | 2nd sample: PASSED + rgb_bitdist| 10| 100000| 100|0.97623808| PASSED | 2nd sample: PASSED + rgb_bitdist| 11| 100000| 100|0.26037998| PASSED | 2nd sample: PASSED + rgb_bitdist| 12| 100000| 100|0.59217788| PASSED | 2nd sample: PASSED + rgb_minimum_distance| 2| 10000| 1000|0.19809129| PASSED | 2nd sample: PASSED + rgb_minimum_distance| 3| 10000| 1000|0.97363365| PASSED | 2nd sample: PASSED + rgb_minimum_distance| 4| 10000| 1000|0.62281709| PASSED | 2nd sample: PASSED + rgb_minimum_distance| 5| 10000| 1000|0.13655852| PASSED | 2nd sample: PASSED + rgb_permutations| 2| 100000| 100|0.33726465| PASSED | 2nd sample: PASSED + rgb_permutations| 3| 100000| 100|0.21992025| PASSED | 2nd sample: WEAK + rgb_permutations| 4| 100000| 100|0.27074573| PASSED | 2nd sample: PASSED + rgb_permutations| 5| 100000| 100|0.76925248| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 0| 1000000| 100|0.91881971| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 1| 1000000| 100|0.08282106| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 2| 1000000| 100|0.55991289| PASSED | 2nd sample: WEAK + rgb_lagged_sum| 3| 1000000| 100|0.94939920| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 4| 1000000| 100|0.21248759| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 5| 1000000| 100|0.99308883| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 6| 1000000| 100|0.83174944| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 7| 1000000| 100|0.49883983| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 8| 1000000| 100|0.99900807| WEAK | 2nd sample: PASSED 
+ rgb_lagged_sum| 9| 1000000| 100|0.74164128| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 10| 1000000| 100|0.53367081| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 11| 1000000| 100|0.41808417| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 12| 1000000| 100|0.96082733| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 13| 1000000| 100|0.38208924| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 14| 1000000| 100|0.98335747| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 15| 1000000| 100|0.68708033| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 16| 1000000| 100|0.49715110| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 17| 1000000| 100|0.68418225| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 18| 1000000| 100|0.97255087| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 19| 1000000| 100|0.99556843| WEAK | 2nd sample: PASSED + rgb_lagged_sum| 20| 1000000| 100|0.50758123| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 21| 1000000| 100|0.98435826| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 22| 1000000| 100|0.15752743| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 23| 1000000| 100|0.98607886| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 24| 1000000| 100|0.86645723| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 25| 1000000| 100|0.87384758| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 26| 1000000| 100|0.98680940| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 27| 1000000| 100|0.56386729| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 28| 1000000| 100|0.16874165| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 29| 1000000| 100|0.10369211| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 30| 1000000| 100|0.91356341| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 31| 1000000| 100|0.42526940| PASSED | 2nd sample: PASSED + rgb_lagged_sum| 32| 1000000| 100|0.99939460| WEAK | 2nd sample: PASSED + rgb_kstest_test| 0| 10000| 1000|0.11414525| PASSED | 2nd sample: PASSED + dab_bytedistrib| 0| 51200000| 1|0.27693890| PASSED | 2nd sample: PASSED + dab_dct| 256| 50000| 1|0.15807123| PASSED | 2nd sample: PASSED + Preparing to run test 207. ntuple = 0 + dab_filltree| 32| 15000000| 1|0.33275771| PASSED | 2nd sample: PASSED + dab_filltree| 32| 15000000| 1|0.15704033| PASSED | 2nd sample: PASSED + Preparing to run test 208. ntuple = 0 + dab_filltree2| 0| 5000000| 1|0.85562670| PASSED | 2nd sample: PASSED + dab_filltree2| 1| 5000000| 1|0.35187836| PASSED | 2nd sample: PASSED + Preparing to run test 209. ntuple = 0 + dab_monobit2| 12| 65000000| 1|0.03099468| PASSED | 2nd sample: PASSED + +`dieharder` output of 22KB of contextswitch (`go version go1.10.3 linux/amd64` on 23.08.2018): + + #=============================================================================# + # dieharder version 3.31.1 Copyright 2003 Robert G. 
Brown # + #=============================================================================# + rng_name | filename |rands/second| + mt19937| output.bin| 1.00e+08 | + #=============================================================================# + test_name |ntup| tsamples |psamples| p-value |Assessment + #=============================================================================# + diehard_birthdays| 0| 100| 100|0.75124818| PASSED + diehard_operm5| 0| 1000000| 100|0.71642114| PASSED + diehard_rank_32x32| 0| 40000| 100|0.66406749| PASSED + diehard_rank_6x8| 0| 100000| 100|0.79742497| PASSED + diehard_bitstream| 0| 2097152| 100|0.68336079| PASSED + diehard_opso| 0| 2097152| 100|0.99670345| WEAK + diehard_oqso| 0| 2097152| 100|0.85930861| PASSED + diehard_dna| 0| 2097152| 100|0.77857540| PASSED + diehard_count_1s_str| 0| 256000| 100|0.27851730| PASSED + diehard_count_1s_byt| 0| 256000| 100|0.29570009| PASSED + diehard_parking_lot| 0| 12000| 100|0.51526020| PASSED + diehard_2dsphere| 2| 8000| 100|0.49199324| PASSED + diehard_3dsphere| 3| 4000| 100|0.99008122| PASSED + diehard_squeeze| 0| 100000| 100|0.95518110| PASSED + diehard_sums| 0| 100| 100|0.00015930| WEAK + diehard_runs| 0| 100000| 100|0.50091086| PASSED + diehard_runs| 0| 100000| 100|0.44091340| PASSED + diehard_craps| 0| 200000| 100|0.77284264| PASSED + diehard_craps| 0| 200000| 100|0.71027434| PASSED + marsaglia_tsang_gcd| 0| 10000000| 100|0.38138922| PASSED + marsaglia_tsang_gcd| 0| 10000000| 100|0.36661590| PASSED + sts_monobit| 1| 100000| 100|0.06209802| PASSED + sts_runs| 2| 100000| 100|0.82506539| PASSED + sts_serial| 1| 100000| 100|0.99198615| PASSED + sts_serial| 2| 100000| 100|0.85604831| PASSED + sts_serial| 3| 100000| 100|0.06613657| PASSED + sts_serial| 3| 100000| 100|0.16787860| PASSED + sts_serial| 4| 100000| 100|0.45227401| PASSED + sts_serial| 4| 100000| 100|0.43529092| PASSED + sts_serial| 5| 100000| 100|0.99912474| WEAK + sts_serial| 5| 100000| 100|0.94754128| PASSED + sts_serial| 6| 100000| 100|0.98406523| PASSED + sts_serial| 6| 100000| 100|0.92895983| PASSED + sts_serial| 7| 100000| 100|0.45965410| PASSED + sts_serial| 7| 100000| 100|0.64185152| PASSED + sts_serial| 8| 100000| 100|0.57922926| PASSED + sts_serial| 8| 100000| 100|0.52390292| PASSED + sts_serial| 9| 100000| 100|0.82722325| PASSED + sts_serial| 9| 100000| 100|0.89384819| PASSED + sts_serial| 10| 100000| 100|0.79877889| PASSED + sts_serial| 10| 100000| 100|0.49562348| PASSED + sts_serial| 11| 100000| 100|0.09217966| PASSED + sts_serial| 11| 100000| 100|0.00342361| WEAK + sts_serial| 12| 100000| 100|0.60119444| PASSED + sts_serial| 12| 100000| 100|0.20420318| PASSED + sts_serial| 13| 100000| 100|0.76867489| PASSED + sts_serial| 13| 100000| 100|0.35717970| PASSED + sts_serial| 14| 100000| 100|0.67364089| PASSED + sts_serial| 14| 100000| 100|0.98667204| PASSED + sts_serial| 15| 100000| 100|0.24328833| PASSED + sts_serial| 15| 100000| 100|0.52098866| PASSED + sts_serial| 16| 100000| 100|0.48845863| PASSED + sts_serial| 16| 100000| 100|0.61943558| PASSED + rgb_bitdist| 1| 100000| 100|0.24694812| PASSED + rgb_bitdist| 2| 100000| 100|0.75873723| PASSED + rgb_bitdist| 3| 100000| 100|0.28670990| PASSED + rgb_bitdist| 4| 100000| 100|0.41966273| PASSED + rgb_bitdist| 5| 100000| 100|0.80463973| PASSED + rgb_bitdist| 6| 100000| 100|0.44747725| PASSED + rgb_bitdist| 7| 100000| 100|0.35848420| PASSED + rgb_bitdist| 8| 100000| 100|0.56585089| PASSED + rgb_bitdist| 9| 100000| 100|0.23179559| PASSED + rgb_bitdist| 10| 100000| 100|0.83369283| PASSED + rgb_bitdist| 11| 
100000| 100|0.74761235| PASSED + rgb_bitdist| 12| 100000| 100|0.50477673| PASSED + rgb_minimum_distance| 2| 10000| 1000|0.29527530| PASSED + rgb_minimum_distance| 3| 10000| 1000|0.83681186| PASSED + rgb_minimum_distance| 4| 10000| 1000|0.85939646| PASSED + rgb_minimum_distance| 5| 10000| 1000|0.90229335| PASSED + rgb_permutations| 2| 100000| 100|0.99010460| PASSED + rgb_permutations| 3| 100000| 100|0.99360922| PASSED + rgb_permutations| 4| 100000| 100|0.30113906| PASSED + rgb_permutations| 5| 100000| 100|0.60701235| PASSED + rgb_lagged_sum| 0| 1000000| 100|0.37080580| PASSED + rgb_lagged_sum| 1| 1000000| 100|0.91852932| PASSED + rgb_lagged_sum| 2| 1000000| 100|0.74568323| PASSED + rgb_lagged_sum| 3| 1000000| 100|0.64070201| PASSED + rgb_lagged_sum| 4| 1000000| 100|0.53802729| PASSED + rgb_lagged_sum| 5| 1000000| 100|0.67865656| PASSED + rgb_lagged_sum| 6| 1000000| 100|0.85161494| PASSED + rgb_lagged_sum| 7| 1000000| 100|0.37312323| PASSED + rgb_lagged_sum| 8| 1000000| 100|0.17841759| PASSED + rgb_lagged_sum| 9| 1000000| 100|0.85795513| PASSED + rgb_lagged_sum| 10| 1000000| 100|0.79843176| PASSED + rgb_lagged_sum| 11| 1000000| 100|0.21320830| PASSED + rgb_lagged_sum| 12| 1000000| 100|0.94709672| PASSED + rgb_lagged_sum| 13| 1000000| 100|0.12600611| PASSED + rgb_lagged_sum| 14| 1000000| 100|0.26780352| PASSED + rgb_lagged_sum| 15| 1000000| 100|0.07862730| PASSED + rgb_lagged_sum| 16| 1000000| 100|0.21102254| PASSED + rgb_lagged_sum| 17| 1000000| 100|0.82967141| PASSED + rgb_lagged_sum| 18| 1000000| 100|0.05818566| PASSED + rgb_lagged_sum| 19| 1000000| 100|0.01010140| PASSED + rgb_lagged_sum| 20| 1000000| 100|0.17941782| PASSED + rgb_lagged_sum| 21| 1000000| 100|0.98442639| PASSED + rgb_lagged_sum| 22| 1000000| 100|0.30352772| PASSED + rgb_lagged_sum| 23| 1000000| 100|0.56855155| PASSED + rgb_lagged_sum| 24| 1000000| 100|0.27280405| PASSED + rgb_lagged_sum| 25| 1000000| 100|0.41141889| PASSED + rgb_lagged_sum| 26| 1000000| 100|0.25389013| PASSED + rgb_lagged_sum| 27| 1000000| 100|0.10313177| PASSED + rgb_lagged_sum| 28| 1000000| 100|0.76610028| PASSED + rgb_lagged_sum| 29| 1000000| 100|0.97903830| PASSED + rgb_lagged_sum| 30| 1000000| 100|0.51216732| PASSED + rgb_lagged_sum| 31| 1000000| 100|0.98578832| PASSED + rgb_lagged_sum| 32| 1000000| 100|0.95078719| PASSED + rgb_kstest_test| 0| 10000| 1000|0.24930712| PASSED + dab_bytedistrib| 0| 51200000| 1|0.51100031| PASSED + dab_dct| 256| 50000| 1|0.28794956| PASSED + Preparing to run test 207. ntuple = 0 + dab_filltree| 32| 15000000| 1|0.93283449| PASSED + dab_filltree| 32| 15000000| 1|0.36488075| PASSED + Preparing to run test 208. ntuple = 0 + dab_filltree2| 0| 5000000| 1|0.94036105| PASSED + dab_filltree2| 1| 5000000| 1|0.30118240| PASSED + Preparing to run test 209. 
ntuple = 0 + dab_monobit2| 12| 65000000| 1|0.00209003| WEAK diff --git a/base/rng/test/main.go b/base/rng/test/main.go new file mode 100644 index 000000000..bc8883eeb --- /dev/null +++ b/base/rng/test/main.go @@ -0,0 +1,191 @@ +package main + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "os" + "runtime" + "strconv" + "time" + + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/base/run" +) + +var ( + module *modules.Module + + outputFile *os.File + outputSize uint64 = 1000000 +) + +func init() { + module = modules.Register("main", prep, start, nil, "rng") +} + +func main() { + runtime.GOMAXPROCS(1) + os.Exit(run.Run()) +} + +func prep() error { + if len(os.Args) < 3 { + fmt.Printf("usage: ./%s {fortuna|tickfeeder} [output size in MB]", os.Args[0]) + return modules.ErrCleanExit + } + + switch os.Args[1] { + case "fortuna": + case "tickfeeder": + default: + return fmt.Errorf("usage: %s {fortuna|tickfeeder}", os.Args[0]) + } + + if len(os.Args) > 3 { + n, err := strconv.ParseUint(os.Args[3], 10, 64) + if err != nil { + return fmt.Errorf("failed to parse output size: %w", err) + } + outputSize = n * 1000000 + } + + var err error + outputFile, err = os.OpenFile(os.Args[2], os.O_CREATE|os.O_WRONLY, 0o0644) //nolint:gosec + if err != nil { + return fmt.Errorf("failed to open output file: %w", err) + } + + return nil +} + +//nolint:gocognit +func start() error { + // generates 1MB and writes to stdout + + log.Infof("writing %dMB to stdout, a \".\" will be printed at every 1024 bytes.", outputSize/1000000) + + switch os.Args[1] { + case "fortuna": + module.StartWorker("fortuna", fortuna) + + case "tickfeeder": + module.StartWorker("noise", noise) + module.StartWorker("tickfeeder", tickfeeder) + + default: + return fmt.Errorf("usage: ./%s {fortuna|tickfeeder}", os.Args[0]) + } + + return nil +} + +func fortuna(_ context.Context) error { + var bytesWritten uint64 + + for { + if module.IsStopping() { + return nil + } + + b, err := rng.Bytes(64) + if err != nil { + return err + } + _, err = outputFile.Write(b) + if err != nil { + return err + } + + bytesWritten += 64 + if bytesWritten%1024 == 0 { + _, _ = os.Stderr.WriteString(".") + } + if bytesWritten%65536 == 0 { + fmt.Fprintf(os.Stderr, "\n%d bytes written\n", bytesWritten) + } + if bytesWritten >= outputSize { + _, _ = os.Stderr.WriteString("\n") + break + } + } + + go modules.Shutdown() //nolint:errcheck + return nil +} + +func tickfeeder(ctx context.Context) error { + var bytesWritten uint64 + var value int64 + var pushes int + + for { + if module.IsStopping() { + return nil + } + + time.Sleep(10 * time.Nanosecond) + + value = (value << 1) | (time.Now().UnixNano() % 2) + pushes++ + + if pushes >= 64 { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(value)) + _, err := outputFile.Write(b) + if err != nil { + return err + } + bytesWritten += 8 + if bytesWritten%1024 == 0 { + _, _ = os.Stderr.WriteString(".") + } + if bytesWritten%65536 == 0 { + fmt.Fprintf(os.Stderr, "\n%d bytes written\n", bytesWritten) + } + pushes = 0 + } + + if bytesWritten >= outputSize { + _, _ = os.Stderr.WriteString("\n") + break + } + } + + go modules.Shutdown() //nolint:errcheck + return nil +} + +func noise(ctx context.Context) error { + // do some aes ctr for noise + + key, _ := hex.DecodeString("6368616e676520746869732070617373") + data := []byte("some plaintext 
x") + + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + panic(err) + } + + stream := cipher.NewCTR(block, iv) + for { + select { + case <-ctx.Done(): + return nil + default: + stream.XORKeyStream(data, data) + } + } +} diff --git a/base/rng/tickfeeder.go b/base/rng/tickfeeder.go new file mode 100644 index 000000000..6dbe69ac0 --- /dev/null +++ b/base/rng/tickfeeder.go @@ -0,0 +1,75 @@ +package rng + +import ( + "context" + "encoding/binary" + "time" +) + +func getTickFeederTickDuration() time.Duration { + // be ready in 1/10 time of reseedAfterSeconds + msecsAvailable := reseedAfterSeconds * 100 + // ex.: reseed after 10 minutes: msecsAvailable = 60000 + // have full entropy after 5 minutes + + // one tick generates 0,125 bits of entropy + ticksNeeded := minFeedEntropy * 8 + // ex.: minimum entropy is 256: ticksNeeded = 2048 + + // msces between ticks + tickMsecs := msecsAvailable / ticksNeeded + // ex.: tickMsecs = 29(,296875) + + // use a minimum of 10 msecs per tick for good entropy + // it would take 21 seconds to get full 256 bits of entropy with 10msec ticks + if tickMsecs < 10 { + tickMsecs = 10 + } + + return time.Duration(tickMsecs) * time.Millisecond +} + +// tickFeeder is a really simple entropy feeder that adds the least significant bit of the current nanosecond unixtime to its pool every time it 'ticks'. +// The more work the program does, the better the quality, as the internal schedular cannot immediately run the goroutine when it's ready. +func tickFeeder(ctx context.Context) error { + var value int64 + var pushes int + feeder := NewFeeder() + defer feeder.CloseFeeder() + + tickDuration := getTickFeederTickDuration() + + for { + // wait for tick + time.Sleep(tickDuration) + + // add tick value + value = (value << 1) | (time.Now().UnixNano() % 2) + pushes++ + + if pushes >= 64 { + // convert to []byte + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(value)) + // reset + pushes = 0 + + // feed + select { + case feeder.input <- &entropyData{ + data: b, + entropy: 8, + }: + case <-ctx.Done(): + return nil + } + } else { + // check if are done + select { + case <-ctx.Done(): + return nil + default: + } + } + } +} diff --git a/base/runtime/module.go b/base/runtime/module.go new file mode 100644 index 000000000..169abc94e --- /dev/null +++ b/base/runtime/module.go @@ -0,0 +1,44 @@ +package runtime + +import ( + "fmt" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/modules" +) + +// DefaultRegistry is the default registry +// that is used by the module-level API. +var DefaultRegistry = NewRegistry() + +func init() { + modules.Register("runtime", nil, startModule, nil, "database") +} + +func startModule() error { + _, err := database.Register(&database.Database{ + Name: "runtime", + Description: "Runtime database", + StorageType: "injected", + ShadowDelete: false, + }) + if err != nil { + return err + } + + if err := DefaultRegistry.InjectAsDatabase("runtime"); err != nil { + return err + } + + if err := startModulesIntegration(); err != nil { + return fmt.Errorf("failed to start modules integration: %w", err) + } + + return nil +} + +// Register is like Registry.Register but uses +// the package DefaultRegistry. 
+func Register(key string, provider ValueProvider) (PushFunc, error) { + return DefaultRegistry.Register(key, provider) +} diff --git a/base/runtime/modules_integration.go b/base/runtime/modules_integration.go new file mode 100644 index 000000000..c85ac330f --- /dev/null +++ b/base/runtime/modules_integration.go @@ -0,0 +1,71 @@ +package runtime + +import ( + "fmt" + "sync" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" +) + +var modulesIntegrationUpdatePusher func(...record.Record) + +func startModulesIntegration() (err error) { + modulesIntegrationUpdatePusher, err = Register("modules/", &ModulesIntegration{}) + if err != nil { + return err + } + + if !modules.SetEventSubscriptionFunc(pushModuleEvent) { + log.Warningf("runtime: failed to register the modules event subscription function") + } + + return nil +} + +// ModulesIntegration provides integration with the modules system. +type ModulesIntegration struct{} + +// Set is called when the value is set from outside. +// If the runtime value is considered read-only ErrReadOnly +// should be returned. It is guaranteed that the key of +// the record passed to Set is prefixed with the key used +// to register the value provider. +func (mi *ModulesIntegration) Set(record.Record) (record.Record, error) { + return nil, ErrReadOnly +} + +// Get should return one or more records that match keyOrPrefix. +// keyOrPrefix is guaranteed to be at least the prefix used to +// register the ValueProvider. +func (mi *ModulesIntegration) Get(keyOrPrefix string) ([]record.Record, error) { + return nil, database.ErrNotFound +} + +type eventData struct { + record.Base + sync.Mutex + Data interface{} +} + +func pushModuleEvent(moduleName, eventName string, internal bool, data interface{}) { + // Create event record and set key. + eventRecord := &eventData{ + Data: data, + } + eventRecord.SetKey(fmt.Sprintf( + "runtime:modules/%s/event/%s", + moduleName, + eventName, + )) + eventRecord.UpdateMeta() + if internal { + eventRecord.Meta().MakeSecret() + eventRecord.Meta().MakeCrownJewel() + } + + // Push event to database subscriptions. + modulesIntegrationUpdatePusher(eventRecord) +} diff --git a/base/runtime/provider.go b/base/runtime/provider.go new file mode 100644 index 000000000..5951e89b6 --- /dev/null +++ b/base/runtime/provider.go @@ -0,0 +1,74 @@ +package runtime + +import ( + "errors" + + "github.com/safing/portmaster/base/database/record" +) + +var ( + // ErrReadOnly should be returned from ValueProvider.Set if a + // runtime record is considered read-only. + ErrReadOnly = errors.New("runtime record is read-only") + // ErrWriteOnly should be returned from ValueProvider.Get if + // a runtime record is considered write-only. + ErrWriteOnly = errors.New("runtime record is write-only") +) + +type ( + // PushFunc is returned when registering a new value provider + // and can be used to inform the database system about the + // availability of a new runtime record value. Similar to + // database.Controller.PushUpdate, the caller must hold + // the lock for each record passed to PushFunc. + PushFunc func(...record.Record) + + // ValueProvider provides access to a runtime-computed + // database record. + ValueProvider interface { + // Set is called when the value is set from outside. + // If the runtime value is considered read-only ErrReadOnly + // should be returned. 
It is guaranteed that the key of + // the record passed to Set is prefixed with the key used + // to register the value provider. + Set(record.Record) (record.Record, error) + // Get should return one or more records that match keyOrPrefix. + // keyOrPrefix is guaranteed to be at least the prefix used to + // register the ValueProvider. + Get(keyOrPrefix string) ([]record.Record, error) + } + + // SimpleValueSetterFunc is a convenience type for implementing a + // write-only value provider. + SimpleValueSetterFunc func(record.Record) (record.Record, error) + + // SimpleValueGetterFunc is a convenience type for implementing a + // read-only value provider. + SimpleValueGetterFunc func(keyOrPrefix string) ([]record.Record, error) +) + +// Set implements ValueProvider.Set and calls fn. +func (fn SimpleValueSetterFunc) Set(r record.Record) (record.Record, error) { + return fn(r) +} + +// Get implements ValueProvider.Get and returns ErrWriteOnly. +func (SimpleValueSetterFunc) Get(_ string) ([]record.Record, error) { + return nil, ErrWriteOnly +} + +// Set implements ValueProvider.Set and returns ErrReadOnly. +func (SimpleValueGetterFunc) Set(r record.Record) (record.Record, error) { + return nil, ErrReadOnly +} + +// Get implements ValueProvider.Get and calls fn. +func (fn SimpleValueGetterFunc) Get(keyOrPrefix string) ([]record.Record, error) { + return fn(keyOrPrefix) +} + +// Compile time checks. +var ( + _ ValueProvider = SimpleValueGetterFunc(nil) + _ ValueProvider = SimpleValueSetterFunc(nil) +) diff --git a/base/runtime/registry.go b/base/runtime/registry.go new file mode 100644 index 000000000..373f3ffc7 --- /dev/null +++ b/base/runtime/registry.go @@ -0,0 +1,335 @@ +package runtime + +import ( + "errors" + "fmt" + "strings" + "sync" + + "github.com/armon/go-radix" + "golang.org/x/sync/errgroup" + + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" + "github.com/safing/portmaster/base/log" +) + +var ( + // ErrKeyTaken is returned when trying to register + // a value provider at database key or prefix that + // is already occupied by another provider. + ErrKeyTaken = errors.New("runtime key or prefix already used") + // ErrKeyUnmanaged is returned when a Put operation + // on an unmanaged key is performed. + ErrKeyUnmanaged = errors.New("runtime key not managed by any provider") + // ErrInjected is returned by Registry.InjectAsDatabase + // if the registry has already been injected. + ErrInjected = errors.New("registry already injected") +) + +// Registry keeps track of registered runtime +// value providers and exposes them via an +// injected database. Users normally just need +// to use the defaul registry provided by this +// package but may consider creating a dedicated +// runtime registry on their own. Registry uses +// a radix tree for value providers and their +// chosen database key/prefix. +type Registry struct { + l sync.RWMutex + providers *radix.Tree + dbController *database.Controller + dbName string +} + +// keyedValueProvider simply wraps a value provider with it's +// registration prefix. +type keyedValueProvider struct { + ValueProvider + key string +} + +// NewRegistry returns a new registry. 
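+//
+// A hedged sketch of using a dedicated registry instead of the package-level
+// default (database name and provider are made up for illustration):
+//
+//	reg := runtime.NewRegistry()
+//	if err := reg.InjectAsDatabase("my-runtime"); err != nil {
+//		return err
+//	}
+//	push, err := reg.Register("stats/", myProvider)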
+func NewRegistry() *Registry { + return &Registry{ + providers: radix.New(), + } +} + +func isPrefixKey(key string) bool { + return strings.HasSuffix(key, "/") +} + +// DatabaseName returns the name of the database where the +// registry has been injected. It returns an empty string +// if InjectAsDatabase has not been called. +func (r *Registry) DatabaseName() string { + r.l.RLock() + defer r.l.RUnlock() + + return r.dbName +} + +// InjectAsDatabase injects the registry as the storage +// database for name. +func (r *Registry) InjectAsDatabase(name string) error { + r.l.Lock() + defer r.l.Unlock() + + if r.dbController != nil { + return ErrInjected + } + + ctrl, err := database.InjectDatabase(name, r.asStorage()) + if err != nil { + return err + } + + r.dbName = name + r.dbController = ctrl + + return nil +} + +// Register registers a new value provider p under keyOrPrefix. The +// returned PushFunc can be used to send update notitifcations to +// database subscribers. Note that keyOrPrefix must end in '/' to be +// accepted as a prefix. +func (r *Registry) Register(keyOrPrefix string, p ValueProvider) (PushFunc, error) { + r.l.Lock() + defer r.l.Unlock() + + // search if there's a provider registered for a prefix + // that matches or is equal to keyOrPrefix. + key, _, ok := r.providers.LongestPrefix(keyOrPrefix) + if ok && (isPrefixKey(key) || key == keyOrPrefix) { + return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, key) + } + + // if keyOrPrefix is a prefix there must not be any provider + // registered for a key that matches keyOrPrefix. + if isPrefixKey(keyOrPrefix) { + foundProvider := "" + r.providers.WalkPrefix(keyOrPrefix, func(s string, _ interface{}) bool { + foundProvider = s + return true + }) + if foundProvider != "" { + return nil, fmt.Errorf("%w: found provider on %s", ErrKeyTaken, foundProvider) + } + } + + r.providers.Insert(keyOrPrefix, &keyedValueProvider{ + ValueProvider: TraceProvider(p), + key: keyOrPrefix, + }) + + log.Tracef("runtime: registered new provider at %s", keyOrPrefix) + + return func(records ...record.Record) { + r.l.RLock() + defer r.l.RUnlock() + + if r.dbController == nil { + return + } + + for _, rec := range records { + r.dbController.PushUpdate(rec) + } + }, nil +} + +// Get returns the runtime value that is identified by key. +// It implements the storage.Interface. +func (r *Registry) Get(key string) (record.Record, error) { + provider := r.getMatchingProvider(key) + if provider == nil { + return nil, database.ErrNotFound + } + + records, err := provider.Get(key) + if err != nil { + // instead of returning ErrWriteOnly to the database interface + // we wrap it in ErrNotFound so the records effectively gets + // hidden. + if errors.Is(err, ErrWriteOnly) { + return nil, database.ErrNotFound + } + return nil, err + } + + // Get performs an exact match so filter out + // and values that do not match key. + for _, r := range records { + if r.DatabaseKey() == key { + return r, nil + } + } + + return nil, database.ErrNotFound +} + +// Put stores the record m in the runtime database. Note that +// ErrReadOnly is returned if there's no value provider responsible +// for m.Key(). +func (r *Registry) Put(m record.Record) (record.Record, error) { + provider := r.getMatchingProvider(m.DatabaseKey()) + if provider == nil { + // if there's no provider for the given value + // return ErrKeyUnmanaged. 
+ return nil, ErrKeyUnmanaged + } + + res, err := provider.Set(m) + if err != nil { + return nil, err + } + return res, nil +} + +// Query performs a query on the runtime registry returning all +// records across all value providers that match q. +// Query implements the storage.Storage interface. +func (r *Registry) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + if _, err := q.Check(); err != nil { + return nil, fmt.Errorf("invalid query: %w", err) + } + + searchPrefix := q.DatabaseKeyPrefix() + providers := r.collectProviderByPrefix(searchPrefix) + if len(providers) == 0 { + return nil, fmt.Errorf("%w: for key %s", ErrKeyUnmanaged, searchPrefix) + } + + iter := iterator.New() + + grp := new(errgroup.Group) + for idx := range providers { + p := providers[idx] + + grp.Go(func() (err error) { + defer recovery(&err) + + key := p.key + if len(searchPrefix) > len(key) { + key = searchPrefix + } + + records, err := p.Get(key) + if err != nil { + if errors.Is(err, ErrWriteOnly) { + return nil + } + return err + } + + for _, r := range records { + r.Lock() + var ( + matchesKey = q.MatchesKey(r.DatabaseKey()) + isValid = r.Meta().CheckValidity() + isAllowed = r.Meta().CheckPermission(local, internal) + + allowed = matchesKey && isValid && isAllowed + ) + if allowed { + allowed = q.MatchesRecord(r) + } + r.Unlock() + + if !allowed { + log.Tracef("runtime: not sending %s for query %s. matchesKey=%v isValid=%v isAllowed=%v", r.DatabaseKey(), searchPrefix, matchesKey, isValid, isAllowed) + continue + } + + select { + case iter.Next <- r: + case <-iter.Done: + return nil + } + } + + return nil + }) + } + + go func() { + err := grp.Wait() + iter.Finish(err) + }() + + return iter, nil +} + +func (r *Registry) getMatchingProvider(key string) *keyedValueProvider { + r.l.RLock() + defer r.l.RUnlock() + + providerKey, provider, ok := r.providers.LongestPrefix(key) + if !ok { + return nil + } + + if !isPrefixKey(providerKey) && providerKey != key { + return nil + } + + return provider.(*keyedValueProvider) //nolint:forcetypeassert +} + +func (r *Registry) collectProviderByPrefix(prefix string) []*keyedValueProvider { + r.l.RLock() + defer r.l.RUnlock() + + // if there's a LongestPrefix provider that's the only one + // we need to ask + if _, p, ok := r.providers.LongestPrefix(prefix); ok { + return []*keyedValueProvider{p.(*keyedValueProvider)} //nolint:forcetypeassert + } + + var providers []*keyedValueProvider + r.providers.WalkPrefix(prefix, func(key string, p interface{}) bool { + providers = append(providers, p.(*keyedValueProvider)) //nolint:forcetypeassert + return false + }) + + return providers +} + +// GetRegistrationKeys returns a list of all provider registration +// keys or prefixes. +func (r *Registry) GetRegistrationKeys() []string { + r.l.RLock() + defer r.l.RUnlock() + + var keys []string + + r.providers.Walk(func(key string, p interface{}) bool { + keys = append(keys, key) + return false + }) + return keys +} + +// asStorage returns a storage.Interface compatible struct +// that is backed by r. 
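+// It is what InjectAsDatabase hands to the database system, so that gets,
+// puts and queries on the injected database are served by the registry.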
+func (r *Registry) asStorage() storage.Interface { + return &storageWrapper{ + Registry: r, + } +} + +func recovery(err *error) { + if x := recover(); x != nil { + if e, ok := x.(error); ok { + *err = e + return + } + + *err = fmt.Errorf("%v", x) + } +} diff --git a/base/runtime/registry_test.go b/base/runtime/registry_test.go new file mode 100644 index 000000000..54084e23a --- /dev/null +++ b/base/runtime/registry_test.go @@ -0,0 +1,157 @@ +package runtime + +import ( + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" +) + +type testRecord struct { + record.Base + sync.Mutex + Value string +} + +func makeTestRecord(key, value string) record.Record { + r := &testRecord{Value: value} + r.CreateMeta() + r.SetKey("runtime:" + key) + return r +} + +type testProvider struct { + k string + r []record.Record +} + +func (tp *testProvider) Get(key string) ([]record.Record, error) { + return tp.r, nil +} + +func (tp *testProvider) Set(r record.Record) (record.Record, error) { + return nil, errors.New("not implemented") +} + +func getTestRegistry(t *testing.T) *Registry { + t.Helper() + + r := NewRegistry() + + providers := []testProvider{ + { + k: "p1/", + r: []record.Record{ + makeTestRecord("p1/f1/v1", "p1.1"), + makeTestRecord("p1/f2/v2", "p1.2"), + makeTestRecord("p1/v3", "p1.3"), + }, + }, + { + k: "p2/f1", + r: []record.Record{ + makeTestRecord("p2/f1/v1", "p2.1"), + makeTestRecord("p2/f1/f2/v2", "p2.2"), + makeTestRecord("p2/f1/v3", "p2.3"), + }, + }, + } + + for idx := range providers { + p := providers[idx] + _, err := r.Register(p.k, &p) + require.NoError(t, err) + } + + return r +} + +func TestRegistryGet(t *testing.T) { + t.Parallel() + + var ( + r record.Record + err error + ) + + reg := getTestRegistry(t) + + r, err = reg.Get("p1/f1/v1") + require.NoError(t, err) + require.NotNil(t, r) + assert.Equal(t, "p1.1", r.(*testRecord).Value) //nolint:forcetypeassert + + r, err = reg.Get("p1/v3") + require.NoError(t, err) + require.NotNil(t, r) + assert.Equal(t, "p1.3", r.(*testRecord).Value) //nolint:forcetypeassert + + r, err = reg.Get("p1/v4") + require.Error(t, err) + assert.Nil(t, r) + + r, err = reg.Get("no-provider/foo") + require.Error(t, err) + assert.Nil(t, r) +} + +func TestRegistryQuery(t *testing.T) { + t.Parallel() + + reg := getTestRegistry(t) + + q := query.New("runtime:p") + iter, err := reg.Query(q, true, true) + require.NoError(t, err) + require.NotNil(t, iter) + var records []record.Record //nolint:prealloc + for r := range iter.Next { + records = append(records, r) + } + assert.Len(t, records, 6) + + q = query.New("runtime:p1/f") + iter, err = reg.Query(q, true, true) + require.NoError(t, err) + require.NotNil(t, iter) + records = nil + for r := range iter.Next { + records = append(records, r) + } + assert.Len(t, records, 2) +} + +func TestRegistryRegister(t *testing.T) { + t.Parallel() + + r := NewRegistry() + + cases := []struct { + inp string + err bool + }{ + {"runtime:foo/bar/bar", false}, + {"runtime:foo/bar/bar2", false}, + {"runtime:foo/bar", false}, + {"runtime:foo/bar", true}, // already used + {"runtime:foo/bar/", true}, // cannot register a prefix if there are providers below + {"runtime:foo/baz/", false}, + {"runtime:foo/baz2/", false}, + {"runtime:foo/baz3", false}, + {"runtime:foo/baz/bar", true}, + } + + for _, c := range cases { + _, err := r.Register(c.inp, nil) + if c.err { + 
assert.Error(t, err, c.inp) + } else { + assert.NoError(t, err, c.inp) + } + } +} diff --git a/base/runtime/singe_record_provider.go b/base/runtime/singe_record_provider.go new file mode 100644 index 000000000..e32dab8e1 --- /dev/null +++ b/base/runtime/singe_record_provider.go @@ -0,0 +1,44 @@ +package runtime + +import "github.com/safing/portmaster/base/database/record" + +// singleRecordReader is a convenience type for read-only exposing +// a single record.Record. Note that users must lock the whole record +// themself before performing any manipulation on the record. +type singleRecordReader struct { + record.Record +} + +// ProvideRecord returns a ValueProvider the exposes read-only +// access to r. Users of ProvideRecord need to ensure the lock +// the whole record before performing modifications on it. +// +// Example: +// +// type MyValue struct { +// record.Base +// Value string +// } +// r := new(MyValue) +// pushUpdate, _ := runtime.Register("my/key", ProvideRecord(r)) +// r.Lock() +// r.Value = "foobar" +// pushUpdate(r) +// r.Unlock() +func ProvideRecord(r record.Record) ValueProvider { + return &singleRecordReader{r} +} + +// Set implements ValueProvider.Set and returns ErrReadOnly. +func (sr *singleRecordReader) Set(_ record.Record) (record.Record, error) { + return nil, ErrReadOnly +} + +// Get implements ValueProvider.Get and returns the wrapped record.Record +// but only if keyOrPrefix exactly matches the records database key. +func (sr *singleRecordReader) Get(keyOrPrefix string) ([]record.Record, error) { + if keyOrPrefix != sr.Record.DatabaseKey() { + return nil, nil + } + return []record.Record{sr.Record}, nil +} diff --git a/base/runtime/storage.go b/base/runtime/storage.go new file mode 100644 index 000000000..fdcc8cd8e --- /dev/null +++ b/base/runtime/storage.go @@ -0,0 +1,32 @@ +package runtime + +import ( + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" +) + +// storageWrapper is a simple wrapper around storage.InjectBase and +// Registry and make sure the supported methods are handled by +// the registry rather than the InjectBase defaults. +// storageWrapper is mainly there to keep the method landscape of +// Registry as small as possible. +type storageWrapper struct { + storage.InjectBase + Registry *Registry +} + +func (sw *storageWrapper) Get(key string) (record.Record, error) { + return sw.Registry.Get(key) +} + +func (sw *storageWrapper) Put(r record.Record) (record.Record, error) { + return sw.Registry.Put(r) +} + +func (sw *storageWrapper) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { + return sw.Registry.Query(q, local, internal) +} + +func (sw *storageWrapper) ReadOnly() bool { return false } diff --git a/base/runtime/trace_provider.go b/base/runtime/trace_provider.go new file mode 100644 index 000000000..2bb2c2cbc --- /dev/null +++ b/base/runtime/trace_provider.go @@ -0,0 +1,37 @@ +package runtime + +import ( + "time" + + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" +) + +// traceValueProvider can be used to wrap an +// existing value provider to trace an calls to +// their Set and Get methods. +type traceValueProvider struct { + ValueProvider +} + +// TraceProvider returns a new ValueProvider that wraps +// vp but traces all Set and Get methods calls. 
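+// Note that Registry.Register already wraps every provider with TraceProvider,
+// so wrapping manually is usually not necessary.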
+func TraceProvider(vp ValueProvider) ValueProvider { + return &traceValueProvider{vp} +} + +func (tvp *traceValueProvider) Set(r record.Record) (res record.Record, err error) { + defer func(start time.Time) { + log.Tracef("runtime: setting record %q: duration=%s err=%v", r.Key(), time.Since(start), err) + }(time.Now()) + + return tvp.ValueProvider.Set(r) +} + +func (tvp *traceValueProvider) Get(keyOrPrefix string) (records []record.Record, err error) { + defer func(start time.Time) { + log.Tracef("runtime: loading records %q: duration=%s err=%v #records=%d", keyOrPrefix, time.Since(start), err, len(records)) + }(time.Now()) + + return tvp.ValueProvider.Get(keyOrPrefix) +} diff --git a/base/template/module.go b/base/template/module.go new file mode 100644 index 000000000..a46a74edc --- /dev/null +++ b/base/template/module.go @@ -0,0 +1,111 @@ +package template + +import ( + "context" + "time" + + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/modules/subsystems" +) + +const ( + eventStateUpdate = "state update" +) + +var module *modules.Module + +func init() { + // register module + module = modules.Register("template", prep, start, stop) // add dependencies... + subsystems.Register( + "template-subsystem", // ID + "Template Subsystem", // name + "This subsystem is a template for quick setup", // description + module, + "config:template", // key space for configuration options registered + &config.Option{ + Name: "Template Subsystem", + Key: "config:subsystems/template", + Description: "This option enables the Template Subsystem [TEMPLATE]", + OptType: config.OptTypeBool, + DefaultValue: false, + }, + ) + + // register events that other modules can subscribe to + module.RegisterEvent(eventStateUpdate, true) +} + +func prep() error { + // register options + err := config.Register(&config.Option{ + Name: "language", + Key: "template/language", + Description: "Sets the language for the template [TEMPLATE]", + OptType: config.OptTypeString, + ExpertiseLevel: config.ExpertiseLevelUser, // default + ReleaseLevel: config.ReleaseLevelStable, // default + RequiresRestart: false, // default + DefaultValue: "en", + ValidationRegex: "^[a-z]{2}$", + }) + if err != nil { + return err + } + + // register event hooks + // do this in prep() and not in start(), as we don't want to register again if module is turned off and on again + err = module.RegisterEventHook( + "template", // event source module name + "state update", // event source name + "react to state changes", // description of hook function + eventHandler, // hook function + ) + if err != nil { + return err + } + + // hint: event hooks and tasks will not be run if module isn't online + return nil +} + +func start() error { + // register tasks + module.NewTask("do something", taskFn).Queue() + + // start service worker + module.StartServiceWorker("do something", 0, serviceWorker) + + return nil +} + +func stop() error { + return nil +} + +func serviceWorker(ctx context.Context) error { + for { + select { + case <-time.After(1 * time.Second): + err := do() + if err != nil { + return err + } + case <-ctx.Done(): + return nil + } + } +} + +func taskFn(ctx context.Context, task *modules.Task) error { + return do() +} + +func eventHandler(ctx context.Context, data interface{}) error { + return do() +} + +func do() error { + return nil +} diff --git a/base/template/module_test.go b/base/template/module_test.go new file mode 100644 index 000000000..2a41f02b9 --- /dev/null 
+++ b/base/template/module_test.go @@ -0,0 +1,54 @@ +package template + +import ( + "fmt" + "os" + "testing" + + _ "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/modules" +) + +func TestMain(m *testing.M) { + // register base module, for database initialization + modules.Register("base", nil, nil, nil) + + // enable module for testing + module.Enable() + + // tmp dir for data root (db & config) + tmpDir, err := os.MkdirTemp("", "portbase-testing-") + if err != nil { + fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err) + os.Exit(1) + } + // initialize data dir + err = dataroot.Initialize(tmpDir, 0o0755) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) + os.Exit(1) + } + + // start modules + var exitCode int + err = modules.Start() + if err != nil { + // starting failed + fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err) + exitCode = 1 + } else { + // run tests + exitCode = m.Run() + } + + // shutdown + _ = modules.Shutdown() + if modules.GetExitStatusCode() != 0 { + exitCode = modules.GetExitStatusCode() + fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err) + } + // clean up and exit + _ = os.RemoveAll(tmpDir) + os.Exit(exitCode) +} diff --git a/base/updater/doc.go b/base/updater/doc.go new file mode 100644 index 000000000..829a5bd38 --- /dev/null +++ b/base/updater/doc.go @@ -0,0 +1,2 @@ +// Package updater is an update registry that manages updates and versions. +package updater diff --git a/base/updater/export.go b/base/updater/export.go new file mode 100644 index 000000000..55b64a3f3 --- /dev/null +++ b/base/updater/export.go @@ -0,0 +1,15 @@ +package updater + +// Export exports the list of resources. +func (reg *ResourceRegistry) Export() map[string]*Resource { + reg.RLock() + defer reg.RUnlock() + + // copy the map + copiedResources := make(map[string]*Resource) + for key, val := range reg.resources { + copiedResources[key] = val.Export() + } + + return copiedResources +} diff --git a/base/updater/fetch.go b/base/updater/fetch.go new file mode 100644 index 000000000..e3c397e47 --- /dev/null +++ b/base/updater/fetch.go @@ -0,0 +1,348 @@ +package updater + +import ( + "bytes" + "context" + "errors" + "fmt" + "hash" + "io" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "time" + + "github.com/safing/jess/filesig" + "github.com/safing/jess/lhash" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils/renameio" +) + +func (reg *ResourceRegistry) fetchFile(ctx context.Context, client *http.Client, rv *ResourceVersion, tries int) error { + // backoff when retrying + if tries > 0 { + select { + case <-ctx.Done(): + return nil // module is shutting down + case <-time.After(time.Duration(tries*tries) * time.Second): + } + } + + // check destination dir + dirPath := filepath.Dir(rv.storagePath()) + err := reg.storageDir.EnsureAbsPath(dirPath) + if err != nil { + return fmt.Errorf("could not create updates folder: %s", dirPath) + } + + // If verification is enabled, download signature first. 
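+	// The signature is fetched from the same update mirrors and verified
+	// against the configured trust store; how a failure is handled depends on
+	// the resource's DownloadPolicy (require, warn or disable).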
+ var ( + verifiedHash *lhash.LabeledHash + sigFileData []byte + ) + if rv.resource.VerificationOptions != nil { + verifiedHash, sigFileData, err = reg.fetchAndVerifySigFile( + ctx, client, + rv.resource.VerificationOptions, + rv.versionedSigPath(), rv.SigningMetadata(), + tries, + ) + + if err != nil { + switch rv.resource.VerificationOptions.DownloadPolicy { + case SignaturePolicyRequire: + return fmt.Errorf("signature verification failed: %w", err) + case SignaturePolicyWarn: + log.Warningf("%s: failed to verify downloaded signature of %s: %s", reg.Name, rv.versionedPath(), err) + case SignaturePolicyDisable: + log.Debugf("%s: failed to verify downloaded signature of %s: %s", reg.Name, rv.versionedPath(), err) + } + } + } + + // open file for writing + atomicFile, err := renameio.TempFile(reg.tmpDir.Path, rv.storagePath()) + if err != nil { + return fmt.Errorf("could not create temp file for download: %w", err) + } + defer atomicFile.Cleanup() //nolint:errcheck // ignore error for now, tmp dir will be cleaned later again anyway + + // start file download + resp, downloadURL, err := reg.makeRequest(ctx, client, rv.versionedPath(), tries) + if err != nil { + return err + } + defer func() { + _ = resp.Body.Close() + }() + + // Write to the hasher at the same time, if needed. + var hasher hash.Hash + var writeDst io.Writer = atomicFile + if verifiedHash != nil { + hasher = verifiedHash.Algorithm().RawHasher() + writeDst = io.MultiWriter(hasher, atomicFile) + } + + // Download and write file. + n, err := io.Copy(writeDst, resp.Body) + if err != nil { + return fmt.Errorf("failed to download %q: %w", downloadURL, err) + } + if resp.ContentLength != n { + return fmt.Errorf("failed to finish download of %q: written %d out of %d bytes", downloadURL, n, resp.ContentLength) + } + + // Before file is finalized, check if hash, if available. + if hasher != nil { + downloadDigest := hasher.Sum(nil) + if verifiedHash.EqualRaw(downloadDigest) { + log.Infof("%s: verified signature of %s", reg.Name, downloadURL) + } else { + switch rv.resource.VerificationOptions.DownloadPolicy { + case SignaturePolicyRequire: + return errors.New("file does not match signed checksum") + case SignaturePolicyWarn: + log.Warningf("%s: checksum does not match file from %s", reg.Name, downloadURL) + case SignaturePolicyDisable: + log.Debugf("%s: checksum does not match file from %s", reg.Name, downloadURL) + } + + // Reset hasher to signal that the sig should not be written. + hasher = nil + } + } + + // Write signature file, if we have one and if verification succeeded. + if len(sigFileData) > 0 && hasher != nil { + sigFilePath := rv.storagePath() + filesig.Extension + err := os.WriteFile(sigFilePath, sigFileData, 0o0644) //nolint:gosec + if err != nil { + switch rv.resource.VerificationOptions.DownloadPolicy { + case SignaturePolicyRequire: + return fmt.Errorf("failed to write signature file %s: %w", sigFilePath, err) + case SignaturePolicyWarn: + log.Warningf("%s: failed to write signature file %s: %s", reg.Name, sigFilePath, err) + case SignaturePolicyDisable: + log.Debugf("%s: failed to write signature file %s: %s", reg.Name, sigFilePath, err) + } + } + } + + // finalize file + err = atomicFile.CloseAtomicallyReplace() + if err != nil { + return fmt.Errorf("%s: failed to finalize file %s: %w", reg.Name, rv.storagePath(), err) + } + // set permissions + if !onWindows { + // TODO: only set executable files to 0755, set other to 0644 + err = os.Chmod(rv.storagePath(), 0o0755) //nolint:gosec // See TODO above. 
+ if err != nil { + log.Warningf("%s: failed to set permissions on downloaded file %s: %s", reg.Name, rv.storagePath(), err) + } + } + + log.Debugf("%s: fetched %s and stored to %s", reg.Name, downloadURL, rv.storagePath()) + return nil +} + +func (reg *ResourceRegistry) fetchMissingSig(ctx context.Context, client *http.Client, rv *ResourceVersion, tries int) error { + // backoff when retrying + if tries > 0 { + select { + case <-ctx.Done(): + return nil // module is shutting down + case <-time.After(time.Duration(tries*tries) * time.Second): + } + } + + // Check destination dir. + dirPath := filepath.Dir(rv.storagePath()) + err := reg.storageDir.EnsureAbsPath(dirPath) + if err != nil { + return fmt.Errorf("could not create updates folder: %s", dirPath) + } + + // Download and verify the missing signature. + verifiedHash, sigFileData, err := reg.fetchAndVerifySigFile( + ctx, client, + rv.resource.VerificationOptions, + rv.versionedSigPath(), rv.SigningMetadata(), + tries, + ) + if err != nil { + switch rv.resource.VerificationOptions.DownloadPolicy { + case SignaturePolicyRequire: + return fmt.Errorf("signature verification failed: %w", err) + case SignaturePolicyWarn: + log.Warningf("%s: failed to verify downloaded signature of %s: %s", reg.Name, rv.versionedPath(), err) + case SignaturePolicyDisable: + log.Debugf("%s: failed to verify downloaded signature of %s: %s", reg.Name, rv.versionedPath(), err) + } + return nil + } + + // Check if the signature matches the resource file. + ok, err := verifiedHash.MatchesFile(rv.storagePath()) + if err != nil { + switch rv.resource.VerificationOptions.DownloadPolicy { + case SignaturePolicyRequire: + return fmt.Errorf("error while verifying resource file: %w", err) + case SignaturePolicyWarn: + log.Warningf("%s: error while verifying resource file %s", reg.Name, rv.storagePath()) + case SignaturePolicyDisable: + log.Debugf("%s: error while verifying resource file %s", reg.Name, rv.storagePath()) + } + return nil + } + if !ok { + switch rv.resource.VerificationOptions.DownloadPolicy { + case SignaturePolicyRequire: + return errors.New("resource file does not match signed checksum") + case SignaturePolicyWarn: + log.Warningf("%s: checksum does not match resource file from %s", reg.Name, rv.storagePath()) + case SignaturePolicyDisable: + log.Debugf("%s: checksum does not match resource file from %s", reg.Name, rv.storagePath()) + } + return nil + } + + // Write signature file. + err = os.WriteFile(rv.storageSigPath(), sigFileData, 0o0644) //nolint:gosec + if err != nil { + switch rv.resource.VerificationOptions.DownloadPolicy { + case SignaturePolicyRequire: + return fmt.Errorf("failed to write signature file %s: %w", rv.storageSigPath(), err) + case SignaturePolicyWarn: + log.Warningf("%s: failed to write signature file %s: %s", reg.Name, rv.storageSigPath(), err) + case SignaturePolicyDisable: + log.Debugf("%s: failed to write signature file %s: %s", reg.Name, rv.storageSigPath(), err) + } + } + + log.Debugf("%s: fetched %s and stored to %s", reg.Name, rv.versionedSigPath(), rv.storageSigPath()) + return nil +} + +func (reg *ResourceRegistry) fetchAndVerifySigFile(ctx context.Context, client *http.Client, verifOpts *VerificationOptions, sigFilePath string, requiredMetadata map[string]string, tries int) (*lhash.LabeledHash, []byte, error) { + // Download signature file. 
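+	// makeRequest rotates through the configured update URLs based on the
+	// retry counter, so a retry may be served by a different mirror.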
+ resp, _, err := reg.makeRequest(ctx, client, sigFilePath, tries) + if err != nil { + return nil, nil, err + } + defer func() { + _ = resp.Body.Close() + }() + sigFileData, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, err + } + + // Extract all signatures. + sigs, err := filesig.ParseSigFile(sigFileData) + switch { + case len(sigs) == 0 && err != nil: + return nil, nil, fmt.Errorf("failed to parse signature file: %w", err) + case len(sigs) == 0: + return nil, nil, errors.New("no signatures found in signature file") + case err != nil: + return nil, nil, fmt.Errorf("failed to parse signature file: %w", err) + } + + // Verify all signatures. + var verifiedHash *lhash.LabeledHash + for _, sig := range sigs { + fd, err := filesig.VerifyFileData( + sig, + requiredMetadata, + verifOpts.TrustStore, + ) + if err != nil { + return nil, sigFileData, err + } + + // Save or check verified hash. + if verifiedHash == nil { + verifiedHash = fd.FileHash() + } else if !fd.FileHash().Equal(verifiedHash) { + // Return an error if two valid hashes mismatch. + // For simplicity, all hash algorithms must be the same for now. + return nil, sigFileData, errors.New("file hashes from different signatures do not match") + } + } + + return verifiedHash, sigFileData, nil +} + +func (reg *ResourceRegistry) fetchData(ctx context.Context, client *http.Client, downloadPath string, tries int) (fileData []byte, downloadedFrom string, err error) { + // backoff when retrying + if tries > 0 { + select { + case <-ctx.Done(): + return nil, "", nil // module is shutting down + case <-time.After(time.Duration(tries*tries) * time.Second): + } + } + + // start file download + resp, downloadURL, err := reg.makeRequest(ctx, client, downloadPath, tries) + if err != nil { + return nil, downloadURL, err + } + defer func() { + _ = resp.Body.Close() + }() + + // download and write file + buf := bytes.NewBuffer(make([]byte, 0, resp.ContentLength)) + n, err := io.Copy(buf, resp.Body) + if err != nil { + return nil, downloadURL, fmt.Errorf("failed to download %q: %w", downloadURL, err) + } + if resp.ContentLength != n { + return nil, downloadURL, fmt.Errorf("failed to finish download of %q: written %d out of %d bytes", downloadURL, n, resp.ContentLength) + } + + return buf.Bytes(), downloadURL, nil +} + +func (reg *ResourceRegistry) makeRequest(ctx context.Context, client *http.Client, downloadPath string, tries int) (resp *http.Response, downloadURL string, err error) { + // parse update URL + updateBaseURL := reg.UpdateURLs[tries%len(reg.UpdateURLs)] + u, err := url.Parse(updateBaseURL) + if err != nil { + return nil, "", fmt.Errorf("failed to parse update URL %q: %w", updateBaseURL, err) + } + // add download path + u.Path = path.Join(u.Path, downloadPath) + // compile URL + downloadURL = u.String() + + // create request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, http.NoBody) + if err != nil { + return nil, "", fmt.Errorf("failed to create request for %q: %w", downloadURL, err) + } + + // set user agent + if reg.UserAgent != "" { + req.Header.Set("User-Agent", reg.UserAgent) + } + + // start request + resp, err = client.Do(req) + if err != nil { + return nil, "", fmt.Errorf("failed to make request to %q: %w", downloadURL, err) + } + + // check return code + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return nil, "", fmt.Errorf("failed to fetch %q: %d %s", downloadURL, resp.StatusCode, resp.Status) + } + + return resp, downloadURL, err +} diff --git 
a/base/updater/file.go b/base/updater/file.go new file mode 100644 index 000000000..90b7d3565 --- /dev/null +++ b/base/updater/file.go @@ -0,0 +1,156 @@ +package updater + +import ( + "errors" + "io" + "io/fs" + "os" + "strings" + + semver "github.com/hashicorp/go-version" + + "github.com/safing/jess/filesig" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" +) + +// File represents a file from the update system. +type File struct { + resource *Resource + version *ResourceVersion + notifier *notifier + versionedPath string + storagePath string +} + +// Identifier returns the identifier of the file. +func (file *File) Identifier() string { + return file.resource.Identifier +} + +// Version returns the version of the file. +func (file *File) Version() string { + return file.version.VersionNumber +} + +// SemVer returns the semantic version of the file. +func (file *File) SemVer() *semver.Version { + return file.version.semVer +} + +// EqualsVersion normalizes the given version and checks equality with semver. +func (file *File) EqualsVersion(version string) bool { + return file.version.EqualsVersion(version) +} + +// Path returns the absolute filepath of the file. +func (file *File) Path() string { + return file.storagePath +} + +// SigningMetadata returns the metadata to be included in signatures. +func (file *File) SigningMetadata() map[string]string { + return map[string]string{ + "id": file.Identifier(), + "version": file.Version(), + } +} + +// Verify verifies the given file. +func (file *File) Verify() ([]*filesig.FileData, error) { + // Check if verification is configured. + if file.resource.VerificationOptions == nil { + return nil, ErrVerificationNotConfigured + } + + // Verify file. + fileData, err := filesig.VerifyFile( + file.storagePath, + file.storagePath+filesig.Extension, + file.SigningMetadata(), + file.resource.VerificationOptions.TrustStore, + ) + if err != nil { + switch file.resource.VerificationOptions.DiskLoadPolicy { + case SignaturePolicyRequire: + return nil, err + case SignaturePolicyWarn: + log.Warningf("%s: failed to verify %s: %s", file.resource.registry.Name, file.storagePath, err) + case SignaturePolicyDisable: + log.Debugf("%s: failed to verify %s: %s", file.resource.registry.Name, file.storagePath, err) + } + } + + return fileData, nil +} + +// Blacklist notifies the update system that this file is somehow broken, and should be ignored from now on, until restarted. +func (file *File) Blacklist() error { + return file.resource.Blacklist(file.version.VersionNumber) +} + +// markActiveWithLocking marks the file as active, locking the resource in the process. +func (file *File) markActiveWithLocking() { + file.resource.Lock() + defer file.resource.Unlock() + + // update last used version + if file.resource.ActiveVersion != file.version { + log.Debugf("updater: setting active version of resource %s from %s to %s", file.resource.Identifier, file.resource.ActiveVersion, file.version.VersionNumber) + file.resource.ActiveVersion = file.version + } +} + +// Unpacker describes the function that is passed to +// File.Unpack. It receives a reader to the compressed/packed +// file and should return a reader that provides +// unpacked file contents. If the returned reader implements +// io.Closer it's close method is invoked when an error +// or io.EOF is returned from Read(). +type Unpacker func(io.Reader) (io.Reader, error) + +// Unpack returns the path to the unpacked version of file and +// unpacks it on demand using unpacker. 
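The Unpacker callback defined above lets callers decide how a packed resource is decompressed, while Unpack takes care of caching the result next to the stored file and closing the reader if it implements io.Closer. As a hypothetical usage sketch (the ".gz" suffix and the unpackGzipped helper are assumptions, not part of this patch), a gzip-packed resource could be unpacked like this:

package example

import (
	"compress/gzip"
	"io"

	"github.com/safing/portmaster/base/updater"
)

// unpackGzipped unpacks a gzip-compressed resource file and returns the
// path to the cached, unpacked copy.
func unpackGzipped(file *updater.File) (string, error) {
	return file.Unpack(".gz", func(r io.Reader) (io.Reader, error) {
		// *gzip.Reader also implements io.Closer, so Unpack will close it.
		return gzip.NewReader(r)
	})
}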
+func (file *File) Unpack(suffix string, unpacker Unpacker) (string, error) { + path := strings.TrimSuffix(file.Path(), suffix) + + if suffix == "" { + path += "-unpacked" + } + + _, err := os.Stat(path) + if err == nil { + return path, nil + } + + if !errors.Is(err, fs.ErrNotExist) { + return "", err + } + + f, err := os.Open(file.Path()) + if err != nil { + return "", err + } + defer func() { + _ = f.Close() + }() + + r, err := unpacker(f) + if err != nil { + return "", err + } + + ioErr := utils.CreateAtomic(path, r, &utils.AtomicFileOptions{ + TempDir: file.resource.registry.TmpDir().Path, + }) + + if c, ok := r.(io.Closer); ok { + if err := c.Close(); err != nil && ioErr == nil { + // if ioErr is already set we ignore the error from + // closing the unpacker. + ioErr = err + } + } + + return path, ioErr +} diff --git a/base/updater/filename.go b/base/updater/filename.go new file mode 100644 index 000000000..69e9db00e --- /dev/null +++ b/base/updater/filename.go @@ -0,0 +1,57 @@ +package updater + +import ( + "path" + "regexp" + "strings" +) + +var ( + fileVersionRegex = regexp.MustCompile(`_v[0-9]+-[0-9]+-[0-9]+(-[a-z]+)?`) + rawVersionRegex = regexp.MustCompile(`^[0-9]+\.[0-9]+\.[0-9]+(-[a-z]+)?$`) +) + +// GetIdentifierAndVersion splits the given file path into its identifier and version. +func GetIdentifierAndVersion(versionedPath string) (identifier, version string, ok bool) { + dirPath, filename := path.Split(versionedPath) + + // Extract version from filename. + rawVersion := fileVersionRegex.FindString(filename) + if rawVersion == "" { + // No version present in file, making it invalid. + return "", "", false + } + + // Trim the `_v` that gets caught by the regex and + // replace `-` with `.` to get the version string. + version = strings.Replace(strings.TrimLeft(rawVersion, "_v"), "-", ".", 2) + + // Put the filename back together without version. + i := strings.Index(filename, rawVersion) + if i < 0 { + // extracted version not in string (impossible) + return "", "", false + } + filename = filename[:i] + filename[i+len(rawVersion):] + + // Put the full path back together and return it. + // `dirPath + filename` is guaranteed by path.Split() + return dirPath + filename, version, true +} + +// GetVersionedPath combines the identifier and version and returns it as a file path. +func GetVersionedPath(identifier, version string) (versionedPath string) { + identifierPath, filename := path.Split(identifier) + + // Split the filename where the version should go. + splittedFilename := strings.SplitN(filename, ".", 2) + // Replace `.` with `-` for the filename format. + transformedVersion := strings.Replace(version, ".", "-", 2) + + // Put everything back together and return it. + versionedPath = identifierPath + splittedFilename[0] + "_v" + transformedVersion + if len(splittedFilename) > 1 { + versionedPath += "." 
+ splittedFilename[1] + } + return versionedPath +} diff --git a/base/updater/filename_test.go b/base/updater/filename_test.go new file mode 100644 index 000000000..cf5fb9224 --- /dev/null +++ b/base/updater/filename_test.go @@ -0,0 +1,80 @@ +package updater + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +func testRegexMatch(t *testing.T, testRegex *regexp.Regexp, testString string, shouldMatch bool) { + t.Helper() + + if testRegex.MatchString(testString) != shouldMatch { + if shouldMatch { + t.Errorf("regex %s should match %s", testRegex, testString) + } else { + t.Errorf("regex %s should not match %s", testRegex, testString) + } + } +} + +func testRegexFind(t *testing.T, testRegex *regexp.Regexp, testString string, shouldMatch bool) { + t.Helper() + + if (testRegex.FindString(testString) != "") != shouldMatch { + if shouldMatch { + t.Errorf("regex %s should find %s", testRegex, testString) + } else { + t.Errorf("regex %s should not find %s", testRegex, testString) + } + } +} + +func testVersionTransformation(t *testing.T, testFilename, testIdentifier, testVersion string) { + t.Helper() + + identifier, version, ok := GetIdentifierAndVersion(testFilename) + if !ok { + t.Errorf("failed to get identifier and version of %s", testFilename) + } + assert.Equal(t, testIdentifier, identifier, "identifier does not match") + assert.Equal(t, testVersion, version, "version does not match") + + versionedPath := GetVersionedPath(testIdentifier, testVersion) + assert.Equal(t, testFilename, versionedPath, "filename (versioned path) does not match") +} + +func TestRegexes(t *testing.T) { + t.Parallel() + + testRegexMatch(t, rawVersionRegex, "0.1.2", true) + testRegexMatch(t, rawVersionRegex, "0.1.2-beta", true) + testRegexMatch(t, rawVersionRegex, "0.1.2-staging", true) + testRegexMatch(t, rawVersionRegex, "12.13.14", true) + + testRegexMatch(t, rawVersionRegex, "v0.1.2", false) + testRegexMatch(t, rawVersionRegex, "0.", false) + testRegexMatch(t, rawVersionRegex, "0.1", false) + testRegexMatch(t, rawVersionRegex, "0.1.", false) + testRegexMatch(t, rawVersionRegex, ".1.2", false) + testRegexMatch(t, rawVersionRegex, ".1.", false) + testRegexMatch(t, rawVersionRegex, "012345", false) + + testRegexFind(t, fileVersionRegex, "/path/to/file_v0-0-0", true) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2-3", true) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2-3.exe", true) + + testRegexFind(t, fileVersionRegex, "/path/to/file-v1-2-3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1.2.3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_1-2-3", false) + testRegexFind(t, fileVersionRegex, "/path/to/file_v1-2", false) + testRegexFind(t, fileVersionRegex, "/path/to/file-v1-2-3", false) + + testVersionTransformation(t, "/path/to/file_v0-0-0", "/path/to/file", "0.0.0") + testVersionTransformation(t, "/path/to/file_v1-2-3", "/path/to/file", "1.2.3") + testVersionTransformation(t, "/path/to/file_v1-2-3-beta", "/path/to/file", "1.2.3-beta") + testVersionTransformation(t, "/path/to/file_v1-2-3-staging", "/path/to/file", "1.2.3-staging") + testVersionTransformation(t, "/path/to/file_v1-2-3.exe", "/path/to/file.exe", "1.2.3") + testVersionTransformation(t, "/path/to/file_v1-2-3-staging.exe", "/path/to/file.exe", "1.2.3-staging") +} diff --git a/base/updater/get.go b/base/updater/get.go new file mode 100644 index 000000000..eb09ba98f --- /dev/null +++ b/base/updater/get.go @@ -0,0 +1,91 @@ +package updater + +import ( + "context" + "errors" + "fmt" 
+ "net/http" + + "github.com/safing/portmaster/base/log" +) + +// Errors returned by the updater package. +var ( + ErrNotFound = errors.New("the requested file could not be found") + ErrNotAvailableLocally = errors.New("the requested file is not available locally") + ErrVerificationNotConfigured = errors.New("verification not configured for this resource") +) + +// GetFile returns the selected (mostly newest) file with the given +// identifier or an error, if it fails. +func (reg *ResourceRegistry) GetFile(identifier string) (*File, error) { + reg.RLock() + res, ok := reg.resources[identifier] + reg.RUnlock() + if !ok { + return nil, ErrNotFound + } + + file := res.GetFile() + // check if file is available locally + if file.version.Available { + file.markActiveWithLocking() + + // Verify file, if configured. + _, err := file.Verify() + if err != nil && !errors.Is(err, ErrVerificationNotConfigured) { + // TODO: If verification is required, try deleting the resource and downloading it again. + return nil, fmt.Errorf("failed to verify file: %w", err) + } + + return file, nil + } + + // check if online + if !reg.Online { + return nil, ErrNotAvailableLocally + } + + // check download dir + err := reg.tmpDir.Ensure() + if err != nil { + return nil, fmt.Errorf("could not prepare tmp directory for download: %w", err) + } + + // Start registry operation. + reg.state.StartOperation(StateFetching) + defer reg.state.EndOperation() + + // download file + log.Tracef("%s: starting download of %s", reg.Name, file.versionedPath) + client := &http.Client{} + for tries := 0; tries < 5; tries++ { + err = reg.fetchFile(context.TODO(), client, file.version, tries) + if err != nil { + log.Tracef("%s: failed to download %s: %s, retrying (%d)", reg.Name, file.versionedPath, err, tries+1) + } else { + file.markActiveWithLocking() + + // TODO: We just download the file - should we verify it again? + return file, nil + } + } + log.Warningf("%s: failed to download %s: %s", reg.Name, file.versionedPath, err) + return nil, err +} + +// GetVersion returns the selected version of the given identifier. +// The returned resource version may not be modified. +func (reg *ResourceRegistry) GetVersion(identifier string) (*ResourceVersion, error) { + reg.RLock() + res, ok := reg.resources[identifier] + reg.RUnlock() + if !ok { + return nil, ErrNotFound + } + + res.Lock() + defer res.Unlock() + + return res.SelectedVersion, nil +} diff --git a/base/updater/indexes.go b/base/updater/indexes.go new file mode 100644 index 000000000..81a373a33 --- /dev/null +++ b/base/updater/indexes.go @@ -0,0 +1,109 @@ +package updater + +import ( + "encoding/json" + "errors" + "fmt" + "time" +) + +const ( + baseIndexExtension = ".json" + v2IndexExtension = ".v2.json" +) + +// Index describes an index file pulled by the updater. +type Index struct { + // Path is the path to the index file + // on the update server. + Path string + + // Channel holds the release channel name of the index. + // It must match the filename without extension. + Channel string + + // PreRelease signifies that all versions of this index should be marked as + // pre-releases, no matter if the versions actually have a pre-release tag or + // not. + PreRelease bool + + // AutoDownload specifies whether new versions should be automatically downloaded. + AutoDownload bool + + // LastRelease holds the time of the last seen release of this index. + LastRelease time.Time +} + +// IndexFile represents an index file. 
+type IndexFile struct {
+	Channel   string
+	Published time.Time
+
+	Releases map[string]string
+}
+
+var (
+	// ErrIndexChecksumMismatch is returned when an index does not match its
+	// signed checksum.
+	ErrIndexChecksumMismatch = errors.New("index checksum does not match signature")
+
+	// ErrIndexFromFuture is returned when an index is parsed with a
+	// Published timestamp that lies in the future.
+	ErrIndexFromFuture = errors.New("index is from the future")
+
+	// ErrIndexIsOlder is returned when an index is parsed with an older
+	// Published timestamp than the current Published timestamp.
+	ErrIndexIsOlder = errors.New("index is older than the current one")
+
+	// ErrIndexChannelMismatch is returned when an index is parsed with a
+	// different channel than the expected one.
+	ErrIndexChannelMismatch = errors.New("index does not match the expected channel")
+)
+
+// ParseIndexFile parses an index file and checks if it is valid.
+func ParseIndexFile(indexData []byte, channel string, lastIndexRelease time.Time) (*IndexFile, error) {
+	// Load into struct.
+	indexFile := &IndexFile{}
+	err := json.Unmarshal(indexData, indexFile)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse signed index data: %w", err)
+	}
+
+	// Fallback to old format if there are no releases and no channel is defined.
+	// TODO: Remove in v1
+	if len(indexFile.Releases) == 0 && indexFile.Channel == "" {
+		return loadOldIndexFormat(indexData, channel)
+	}
+
+	// Check the index metadata.
+	switch {
+	case !indexFile.Published.IsZero() && time.Now().Before(indexFile.Published):
+		return indexFile, ErrIndexFromFuture
+
+	case !indexFile.Published.IsZero() &&
+		!lastIndexRelease.IsZero() &&
+		lastIndexRelease.After(indexFile.Published):
+		return indexFile, ErrIndexIsOlder
+
+	case channel != "" &&
+		indexFile.Channel != "" &&
+		channel != indexFile.Channel:
+		return indexFile, ErrIndexChannelMismatch
+	}
+
+	return indexFile, nil
+}
+
+func loadOldIndexFormat(indexData []byte, channel string) (*IndexFile, error) {
+	releases := make(map[string]string)
+	err := json.Unmarshal(indexData, &releases)
+	if err != nil {
+		return nil, err
+	}
+
+	return &IndexFile{
+		Channel: channel,
+		// Do NOT define `Published`, as this would break the "is newer" check.
+ Releases: releases, + }, nil +} diff --git a/base/updater/indexes_test.go b/base/updater/indexes_test.go new file mode 100644 index 000000000..a85046cda --- /dev/null +++ b/base/updater/indexes_test.go @@ -0,0 +1,57 @@ +package updater + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + oldFormat = `{ + "all/ui/modules/assets.zip": "0.3.0", + "all/ui/modules/portmaster.zip": "0.2.4", + "linux_amd64/core/portmaster-core": "0.8.13" +}` + + newFormat = `{ + "Channel": "stable", + "Published": "2022-01-02T00:00:00Z", + "Releases": { + "all/ui/modules/assets.zip": "0.3.0", + "all/ui/modules/portmaster.zip": "0.2.4", + "linux_amd64/core/portmaster-core": "0.8.13" + } +}` + + formatTestChannel = "stable" + formatTestReleases = map[string]string{ + "all/ui/modules/assets.zip": "0.3.0", + "all/ui/modules/portmaster.zip": "0.2.4", + "linux_amd64/core/portmaster-core": "0.8.13", + } +) + +func TestIndexParsing(t *testing.T) { + t.Parallel() + + lastRelease, err := time.Parse(time.RFC3339, "2022-01-01T00:00:00Z") + if err != nil { + t.Fatal(err) + } + + oldIndexFile, err := ParseIndexFile([]byte(oldFormat), formatTestChannel, lastRelease) + if err != nil { + t.Fatal(err) + } + + newIndexFile, err := ParseIndexFile([]byte(newFormat), formatTestChannel, lastRelease) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, formatTestChannel, oldIndexFile.Channel, "channel should be the same") + assert.Equal(t, formatTestChannel, newIndexFile.Channel, "channel should be the same") + assert.Equal(t, formatTestReleases, oldIndexFile.Releases, "releases should be the same") + assert.Equal(t, formatTestReleases, newIndexFile.Releases, "releases should be the same") +} diff --git a/base/updater/notifier.go b/base/updater/notifier.go new file mode 100644 index 000000000..66b2832df --- /dev/null +++ b/base/updater/notifier.go @@ -0,0 +1,33 @@ +package updater + +import ( + "github.com/tevino/abool" +) + +type notifier struct { + upgradeAvailable *abool.AtomicBool + notifyChannel chan struct{} +} + +func newNotifier() *notifier { + return ¬ifier{ + upgradeAvailable: abool.NewBool(false), + notifyChannel: make(chan struct{}), + } +} + +func (n *notifier) markAsUpgradeable() { + if n.upgradeAvailable.SetToIf(false, true) { + close(n.notifyChannel) + } +} + +// UpgradeAvailable returns whether an upgrade is available for this file. +func (file *File) UpgradeAvailable() bool { + return file.notifier.upgradeAvailable.IsSet() +} + +// WaitForAvailableUpgrade blocks (selectable) until an upgrade for this file is available. +func (file *File) WaitForAvailableUpgrade() <-chan struct{} { + return file.notifier.notifyChannel +} diff --git a/base/updater/registry.go b/base/updater/registry.go new file mode 100644 index 000000000..8deda74ed --- /dev/null +++ b/base/updater/registry.go @@ -0,0 +1,270 @@ +package updater + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" +) + +const ( + onWindows = runtime.GOOS == "windows" +) + +// ResourceRegistry is a registry for managing update resources. 
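The notifier above broadcasts an upgrade exactly once by closing a channel, which wakes every waiter of WaitForAvailableUpgrade at the same time; the abool guard makes the close idempotent. The same one-shot pattern can be expressed with just the standard library (a sketch for comparison, not a replacement for the notifier in this patch):

package example

import "sync"

// oneShot signals a single event to any number of waiters by closing a
// channel; sync.Once plays the role of the abool compare-and-set above.
type oneShot struct {
	once sync.Once
	ch   chan struct{}
}

func newOneShot() *oneShot { return &oneShot{ch: make(chan struct{})} }

// fire closes the channel exactly once, waking all current and future waiters.
func (o *oneShot) fire() { o.once.Do(func() { close(o.ch) }) }

// wait returns a channel that is closed once fire has been called.
func (o *oneShot) wait() <-chan struct{} { return o.ch }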
+type ResourceRegistry struct { + sync.RWMutex + + Name string + storageDir *utils.DirStructure + tmpDir *utils.DirStructure + indexes []*Index + state *RegistryState + + resources map[string]*Resource + UpdateURLs []string + UserAgent string + MandatoryUpdates []string + AutoUnpack []string + + // Verification holds a map of VerificationOptions assigned to their + // applicable identifier path prefix. + // Use an empty string to denote the default. + // Use empty options to disable verification for a path prefix. + Verification map[string]*VerificationOptions + + // UsePreReleases signifies that pre-releases should be used when selecting a + // version. Even if false, a pre-release version will still be used if it is + // defined as the current version by an index. + UsePreReleases bool + + // DevMode specifies if a local 0.0.0 version should be always chosen, when available. + DevMode bool + + // Online specifies if resources may be downloaded if not available locally. + Online bool + + // StateNotifyFunc may be set to receive any changes to the registry state. + // The specified function may lock the state, but may not block or take a + // lot of time. + StateNotifyFunc func(*RegistryState) +} + +// AddIndex adds a new index to the resource registry. +// The order is important, as indexes added later will override the current +// release from earlier indexes. +func (reg *ResourceRegistry) AddIndex(idx Index) { + reg.Lock() + defer reg.Unlock() + + // Get channel name from path. + idx.Channel = strings.TrimSuffix( + filepath.Base(idx.Path), filepath.Ext(idx.Path), + ) + + reg.indexes = append(reg.indexes, &idx) +} + +// PreInitUpdateState sets the initial update state of the registry before initialization. +func (reg *ResourceRegistry) PreInitUpdateState(s UpdateState) error { + if reg.state != nil { + return errors.New("registry already initialized") + } + + reg.state = &RegistryState{ + Updates: s, + } + return nil +} + +// Initialize initializes a raw registry struct and makes it ready for usage. +func (reg *ResourceRegistry) Initialize(storageDir *utils.DirStructure) error { + // check if storage dir is available + err := storageDir.Ensure() + if err != nil { + return err + } + + // set default name + if reg.Name == "" { + reg.Name = "updater" + } + + // initialize private attributes + reg.storageDir = storageDir + reg.tmpDir = storageDir.ChildDir("tmp", 0o0700) + reg.resources = make(map[string]*Resource) + if reg.state == nil { + reg.state = &RegistryState{} + } + reg.state.ID = StateReady + reg.state.reg = reg + + // remove tmp dir to delete old entries + err = reg.Cleanup() + if err != nil { + log.Warningf("%s: failed to remove tmp dir: %s", reg.Name, err) + } + + // (re-)create tmp dir + err = reg.tmpDir.Ensure() + if err != nil { + log.Warningf("%s: failed to create tmp dir: %s", reg.Name, err) + } + + // Check verification options. + if reg.Verification != nil { + for prefix, opts := range reg.Verification { + // Check if verification is disable for this prefix. + if opts == nil { + continue + } + + // If enabled, a trust store is required. + if opts.TrustStore == nil { + return fmt.Errorf("verification enabled for prefix %q, but no trust store configured", prefix) + } + + // DownloadPolicy must be equal or stricter than DiskLoadPolicy. + if opts.DiskLoadPolicy < opts.DownloadPolicy { + return errors.New("verification download policy must be equal or stricter than the disk load policy") + } + + // Warn if all policies are disabled. 
+			if opts.DownloadPolicy == SignaturePolicyDisable &&
+				opts.DiskLoadPolicy == SignaturePolicyDisable {
+				log.Warningf("%s: verification enabled for prefix %q, but all policies set to disable", reg.Name, prefix)
+			}
+		}
+	}
+
+	return nil
+}
+
+// StorageDir returns the main storage dir of the resource registry.
+func (reg *ResourceRegistry) StorageDir() *utils.DirStructure {
+	return reg.storageDir
+}
+
+// TmpDir returns the temporary working dir of the resource registry.
+func (reg *ResourceRegistry) TmpDir() *utils.DirStructure {
+	return reg.tmpDir
+}
+
+// SetDevMode sets the development mode flag.
+func (reg *ResourceRegistry) SetDevMode(on bool) {
+	reg.Lock()
+	defer reg.Unlock()
+
+	reg.DevMode = on
+}
+
+// SetUsePreReleases sets the UsePreReleases flag.
+func (reg *ResourceRegistry) SetUsePreReleases(yes bool) {
+	reg.Lock()
+	defer reg.Unlock()
+
+	reg.UsePreReleases = yes
+}
+
+// AddResource adds a resource to the registry. Does _not_ select new version.
+func (reg *ResourceRegistry) AddResource(identifier, version string, index *Index, available, currentRelease, preRelease bool) error {
+	reg.Lock()
+	defer reg.Unlock()
+
+	err := reg.addResource(identifier, version, index, available, currentRelease, preRelease)
+	return err
+}
+
+func (reg *ResourceRegistry) addResource(identifier, version string, index *Index, available, currentRelease, preRelease bool) error {
+	res, ok := reg.resources[identifier]
+	if !ok {
+		res = reg.newResource(identifier)
+		reg.resources[identifier] = res
+	}
+	res.Index = index
+
+	return res.AddVersion(version, available, currentRelease, preRelease)
+}
+
+// AddResources adds resources to the registry. Errors are logged, the last one is returned. Despite errors, non-failing resources are still added. Does _not_ select new versions.
+func (reg *ResourceRegistry) AddResources(versions map[string]string, index *Index, available, currentRelease, preRelease bool) error {
+	reg.Lock()
+	defer reg.Unlock()
+
+	// add versions and their flags to registry
+	var lastError error
+	for identifier, version := range versions {
+		lastError = reg.addResource(identifier, version, index, available, currentRelease, preRelease)
+		if lastError != nil {
+			log.Warningf("%s: failed to add resource %s: %s", reg.Name, identifier, lastError)
+		}
+	}
+
+	return lastError
+}
+
+// SelectVersions selects new resource versions depending on the current registry state.
+func (reg *ResourceRegistry) SelectVersions() {
+	reg.RLock()
+	defer reg.RUnlock()
+
+	for _, res := range reg.resources {
+		res.Lock()
+		res.selectVersion()
+		res.Unlock()
+	}
+}
+
+// GetSelectedVersions returns a list of the currently selected versions.
+func (reg *ResourceRegistry) GetSelectedVersions() (versions map[string]string) {
+	reg.RLock()
+	defer reg.RUnlock()
+
+	// Initialize the map; assigning to the nil zero value would panic.
+	versions = make(map[string]string, len(reg.resources))
+
+	for _, res := range reg.resources {
+		res.Lock()
+		versions[res.Identifier] = res.SelectedVersion.VersionNumber
+		res.Unlock()
+	}
+
+	return
+}
+
+// Purge deletes old updates, retaining a certain amount, specified by the keep
+// parameter. Will at least keep 2 updates per resource.
+func (reg *ResourceRegistry) Purge(keep int) {
+	reg.RLock()
+	defer reg.RUnlock()
+
+	for _, res := range reg.resources {
+		res.Purge(keep)
+	}
+}
+
+// ResetResources removes all resources from the registry.
+func (reg *ResourceRegistry) ResetResources() {
+	reg.Lock()
+	defer reg.Unlock()
+
+	reg.resources = make(map[string]*Resource)
+}
+
+// ResetIndexes removes all indexes from the registry.
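Taken together, the registry methods above suggest a bootstrap sequence of initialize, register indexes, scan what is already on disk, load the indexes, and then select versions. A plausible wiring (the bootstrap helper, paths, and URLs are placeholders; error handling is abbreviated):

package example

import (
	"context"
	"log"

	"github.com/safing/portmaster/base/updater"
	"github.com/safing/portmaster/base/utils"
)

// bootstrap wires up a registry: storage dir, update mirror, one index.
func bootstrap(ctx context.Context, dataDir string) (*updater.ResourceRegistry, error) {
	reg := &updater.ResourceRegistry{
		Name:       "updates",
		UpdateURLs: []string{"https://updates.example.com"},
		Online:     true,
	}
	if err := reg.Initialize(utils.NewDirStructure(dataDir, 0o0755)); err != nil {
		return nil, err
	}
	reg.AddIndex(updater.Index{Path: "stable.json", AutoDownload: true})

	if err := reg.ScanStorage(""); err != nil {
		log.Printf("storage scan finished with warnings: %s", err)
	}
	if err := reg.LoadIndexes(ctx); err != nil {
		return nil, err
	}
	reg.SelectVersions()
	return reg, nil
}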
+func (reg *ResourceRegistry) ResetIndexes() { + reg.Lock() + defer reg.Unlock() + + reg.indexes = make([]*Index, 0, len(reg.indexes)) +} + +// Cleanup removes temporary files. +func (reg *ResourceRegistry) Cleanup() error { + // delete download tmp dir + return os.RemoveAll(reg.tmpDir.Path) +} diff --git a/base/updater/registry_test.go b/base/updater/registry_test.go new file mode 100644 index 000000000..a8978f68c --- /dev/null +++ b/base/updater/registry_test.go @@ -0,0 +1,35 @@ +package updater + +import ( + "os" + "testing" + + "github.com/safing/portmaster/base/utils" +) + +var registry *ResourceRegistry + +func TestMain(m *testing.M) { + // setup + tmpDir, err := os.MkdirTemp("", "ci-portmaster-") + if err != nil { + panic(err) + } + registry = &ResourceRegistry{ + UsePreReleases: true, + DevMode: true, + Online: true, + } + err = registry.Initialize(utils.NewDirStructure(tmpDir, 0o0777)) + if err != nil { + panic(err) + } + + // run + // call flag.Parse() here if TestMain uses flags + ret := m.Run() + + // teardown + _ = os.RemoveAll(tmpDir) + os.Exit(ret) +} diff --git a/base/updater/resource.go b/base/updater/resource.go new file mode 100644 index 000000000..325f70cc5 --- /dev/null +++ b/base/updater/resource.go @@ -0,0 +1,582 @@ +package updater + +import ( + "errors" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + semver "github.com/hashicorp/go-version" + + "github.com/safing/jess/filesig" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" +) + +var devVersion *semver.Version + +func init() { + var err error + devVersion, err = semver.NewVersion("0") + if err != nil { + panic(err) + } +} + +// Resource represents a resource (via an identifier) and multiple file versions. +type Resource struct { + sync.Mutex + registry *ResourceRegistry + notifier *notifier + + // Identifier is the unique identifier for that resource. + // It forms a file path using a forward-slash as the + // path separator. + Identifier string + + // Versions holds all available resource versions. + Versions []*ResourceVersion + + // ActiveVersion is the last version of the resource + // that someone requested using GetFile(). + ActiveVersion *ResourceVersion + + // SelectedVersion is newest, selectable version of + // that resource that is available. A version + // is selectable if it's not blacklisted by the user. + // Note that it's not guaranteed that the selected version + // is available locally. In that case, GetFile will attempt + // to download the latest version from the updates servers + // specified in the resource registry. + SelectedVersion *ResourceVersion + + // VerificationOptions holds the verification options for this resource. + VerificationOptions *VerificationOptions + + // Index holds a reference to the index this resource was last defined in. + // Will be nil if resource was only found on disk. + Index *Index +} + +// ResourceVersion represents a single version of a resource. +type ResourceVersion struct { + resource *Resource + + // VersionNumber is the string representation of the resource + // version. + VersionNumber string + semVer *semver.Version + + // Available indicates if this version is available locally. + Available bool + + // SigAvailable indicates if the signature of this version is available locally. + SigAvailable bool + + // CurrentRelease indicates that this is the current release that should be + // selected, if possible. + CurrentRelease bool + + // PreRelease indicates that this version is pre-release. 
+ PreRelease bool + + // Blacklisted may be set to true if this version should + // be skipped and not used. This is useful if the version + // is known to be broken. + Blacklisted bool +} + +func (rv *ResourceVersion) String() string { + return rv.VersionNumber +} + +// SemVer returns the semantic version of the resource. +func (rv *ResourceVersion) SemVer() *semver.Version { + return rv.semVer +} + +// EqualsVersion normalizes the given version and checks equality with semver. +func (rv *ResourceVersion) EqualsVersion(version string) bool { + cmpSemVer, err := semver.NewVersion(version) + if err != nil { + return false + } + + return rv.semVer.Equal(cmpSemVer) +} + +// isSelectable returns true if the version represented by rv is selectable. +// A version is selectable if it's not blacklisted and either already locally +// available or ready to be downloaded. +func (rv *ResourceVersion) isSelectable() bool { + switch { + case rv.Blacklisted: + // Should not be used. + return false + case rv.Available: + // Is available locally, use! + return true + case !rv.resource.registry.Online: + // Cannot download, because registry is set to offline. + return false + case rv.resource.Index == nil: + // Cannot download, because resource is not part of an index. + return false + case !rv.resource.Index.AutoDownload: + // Cannot download, because index may not automatically download. + return false + default: + // Is not available locally, but we are allowed to download it on request! + return true + } +} + +// isBetaVersionNumber checks if rv is marked as a beta version by checking +// the version string. It does not honor the BetaRelease field of rv! +func (rv *ResourceVersion) isBetaVersionNumber() bool { //nolint:unused + // "b" suffix check if for backwards compatibility + // new versions should use the pre-release suffix as + // declared by https://semver.org + // i.e. 1.2.3-beta + switch rv.semVer.Prerelease() { + case "b", "beta": + return true + default: + return false + } +} + +// Export makes a copy of the resource with only the exposed information. +// Attributes are copied and safe to access. +// Any ResourceVersion must not be modified. +func (res *Resource) Export() *Resource { + res.Lock() + defer res.Unlock() + + // Copy attibutes. + export := &Resource{ + Identifier: res.Identifier, + Versions: make([]*ResourceVersion, len(res.Versions)), + ActiveVersion: res.ActiveVersion, + SelectedVersion: res.SelectedVersion, + } + // Copy Versions slice. + copy(export.Versions, res.Versions) + + return export +} + +// Len is the number of elements in the collection. +// It implements sort.Interface for ResourceVersion. +func (res *Resource) Len() int { + return len(res.Versions) +} + +// Less reports whether the element with index i should +// sort before the element with index j. +// It implements sort.Interface for ResourceVersions. +func (res *Resource) Less(i, j int) bool { + return res.Versions[i].semVer.GreaterThan(res.Versions[j].semVer) +} + +// Swap swaps the elements with indexes i and j. +// It implements sort.Interface for ResourceVersions. +func (res *Resource) Swap(i, j int) { + res.Versions[i], res.Versions[j] = res.Versions[j], res.Versions[i] +} + +// available returns whether any version of the resource is available. +func (res *Resource) available() bool { + for _, rv := range res.Versions { + if rv.Available { + return true + } + } + return false +} + +// inUse returns true if the resource is currently in use. 
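Version strings are parsed with hashicorp/go-version and kept sorted newest-first via the sort.Interface implementation above (Less uses GreaterThan). A small standalone illustration of the parsing, normalization, and ordering this relies on:

package main

import (
	"fmt"
	"sort"

	semver "github.com/hashicorp/go-version"
)

func main() {
	raw := []string{"1.2.3", "1.2.4-beta", "0", "1.2.10"}
	versions := make([]*semver.Version, 0, len(raw))
	for _, r := range raw {
		v, err := semver.NewVersion(r)
		if err != nil {
			panic(err)
		}
		versions = append(versions, v)
	}
	// Newest first, mirroring Resource.Less above.
	sort.Slice(versions, func(i, j int) bool { return versions[i].GreaterThan(versions[j]) })
	for _, v := range versions {
		fmt.Println(v.String())
	}
	// Prints: 1.2.10, 1.2.4-beta, 1.2.3, 0.0.0 ("0" is normalized to "0.0.0").
}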
+func (res *Resource) inUse() bool { + return res.ActiveVersion != nil +} + +// AnyVersionAvailable returns true if any version of +// res is locally available. +func (res *Resource) AnyVersionAvailable() bool { + res.Lock() + defer res.Unlock() + + return res.available() +} + +func (reg *ResourceRegistry) newResource(identifier string) *Resource { + return &Resource{ + registry: reg, + Identifier: identifier, + Versions: make([]*ResourceVersion, 0, 1), + VerificationOptions: reg.GetVerificationOptions(identifier), + } +} + +// AddVersion adds a resource version to a resource. +func (res *Resource) AddVersion(version string, available, currentRelease, preRelease bool) error { + res.Lock() + defer res.Unlock() + + // reset current release flags + if currentRelease { + for _, rv := range res.Versions { + rv.CurrentRelease = false + } + } + + var rv *ResourceVersion + // check for existing version + for _, possibleMatch := range res.Versions { + if possibleMatch.VersionNumber == version { + rv = possibleMatch + break + } + } + + // create new version if none found + if rv == nil { + // parse to semver + sv, err := semver.NewVersion(version) + if err != nil { + return err + } + + rv = &ResourceVersion{ + resource: res, + VersionNumber: sv.String(), // Use normalized version. + semVer: sv, + } + res.Versions = append(res.Versions, rv) + } + + // set flags + if available { + rv.Available = true + + // If available and signatures are enabled for this resource, check if the + // signature is available. + if res.VerificationOptions != nil && utils.PathExists(rv.storageSigPath()) { + rv.SigAvailable = true + } + } + if currentRelease { + rv.CurrentRelease = true + } + if preRelease || rv.semVer.Prerelease() != "" { + rv.PreRelease = true + } + + return nil +} + +// GetFile returns the selected version as a *File. +func (res *Resource) GetFile() *File { + res.Lock() + defer res.Unlock() + + // check for notifier + if res.notifier == nil { + // create new notifier + res.notifier = newNotifier() + } + + // check if version is selected + if res.SelectedVersion == nil { + res.selectVersion() + } + + // create file + return &File{ + resource: res, + version: res.SelectedVersion, + notifier: res.notifier, + versionedPath: res.SelectedVersion.versionedPath(), + storagePath: res.SelectedVersion.storagePath(), + } +} + +//nolint:gocognit // function already kept as simple as possible +func (res *Resource) selectVersion() { + sort.Sort(res) + + // export after we finish + var fallback bool + defer func() { + if fallback { + log.Tracef("updater: selected version %s (as fallback) for resource %s", res.SelectedVersion, res.Identifier) + } else { + log.Debugf("updater: selected version %s for resource %s", res.SelectedVersion, res.Identifier) + } + + if res.inUse() && + res.SelectedVersion != res.ActiveVersion && // new selected version does not match previously selected version + res.notifier != nil { + + res.notifier.markAsUpgradeable() + res.notifier = nil + + log.Debugf("updater: active version of %s is %s, update available", res.Identifier, res.ActiveVersion.VersionNumber) + } + }() + + if len(res.Versions) == 0 { + // TODO: find better way to deal with an empty version slice (which should not happen) + res.SelectedVersion = nil + return + } + + // Target selection + + // 1) Dev release if dev mode is active and ignore blacklisting + if res.registry.DevMode { + // Get last version, as this will be v0.0.0, if available. + rv := res.Versions[len(res.Versions)-1] + // Check if it's v0.0.0. 
+ if rv.semVer.Equal(devVersion) && rv.Available { + res.SelectedVersion = rv + return + } + } + + // 2) Find the current release. This may be also be a pre-release. + for _, rv := range res.Versions { + if rv.CurrentRelease { + if rv.isSelectable() { + res.SelectedVersion = rv + return + } + // There can only be once current release, + // so we can abort after finding one. + break + } + } + + // 3) If UsePreReleases is set, find any newest version. + if res.registry.UsePreReleases { + for _, rv := range res.Versions { + if rv.isSelectable() { + res.SelectedVersion = rv + return + } + } + } + + // 4) Find the newest stable version. + for _, rv := range res.Versions { + if !rv.PreRelease && rv.isSelectable() { + res.SelectedVersion = rv + return + } + } + + // 5) Default to newest. + res.SelectedVersion = res.Versions[0] + fallback = true +} + +// Blacklist blacklists the specified version and selects a new version. +func (res *Resource) Blacklist(version string) error { + res.Lock() + defer res.Unlock() + + // count available and valid versions + valid := 0 + for _, rv := range res.Versions { + if rv.semVer.Equal(devVersion) { + continue // ignore dev versions + } + if !rv.Blacklisted { + valid++ + } + } + if valid <= 1 { + return errors.New("cannot blacklist last version") // last one, cannot blacklist! + } + + // find version and blacklist + for _, rv := range res.Versions { + if rv.VersionNumber == version { + // blacklist and update + rv.Blacklisted = true + res.selectVersion() + return nil + } + } + + return errors.New("could not find version") +} + +// Purge deletes old updates, retaining a certain amount, specified by +// the keep parameter. Purge will always keep at least 2 versions so +// specifying a smaller keep value will have no effect. +func (res *Resource) Purge(keepExtra int) { //nolint:gocognit + res.Lock() + defer res.Unlock() + + // If there is any blacklisted version within the resource, pause purging. + // In this case we may need extra available versions beyond what would be + // available after purging. + for _, rv := range res.Versions { + if rv.Blacklisted { + log.Debugf( + "%s: pausing purging of resource %s, as it contains blacklisted items", + res.registry.Name, + rv.resource.Identifier, + ) + return + } + } + + // Safeguard the amount of extra version to keep. + if keepExtra < 2 { + keepExtra = 2 + } + + // Search for purge boundary. + var purgeBoundary int + var skippedActiveVersion bool + var skippedSelectedVersion bool + var skippedStableVersion bool +boundarySearch: + for i, rv := range res.Versions { + // Check if required versions are already skipped. + switch { + case !skippedActiveVersion && res.ActiveVersion != nil: + // Skip versions until the active version, if it's set. + case !skippedSelectedVersion && res.SelectedVersion != nil: + // Skip versions until the selected version, if it's set. + case !skippedStableVersion: + // Skip versions until the stable version. + default: + // All required version skipped, set purge boundary. + purgeBoundary = i + keepExtra + break boundarySearch + } + + // Check if current instance is a required version. + if rv == res.ActiveVersion { + skippedActiveVersion = true + } + if rv == res.SelectedVersion { + skippedSelectedVersion = true + } + if !rv.PreRelease { + skippedStableVersion = true + } + } + + // Check if there is anything to purge at all. + if purgeBoundary <= keepExtra || purgeBoundary >= len(res.Versions) { + return + } + + // Purge everything beyond the purge boundary. 
+ for _, rv := range res.Versions[purgeBoundary:] { + // Only remove if resource file is actually available. + if !rv.Available { + continue + } + + // Remove resource file. + storagePath := rv.storagePath() + err := os.Remove(storagePath) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + log.Warningf("%s: failed to purge resource %s v%s: %s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber, err) + } + } else { + log.Tracef("%s: purged resource %s v%s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber) + } + + // Remove resource signature file. + err = os.Remove(rv.storageSigPath()) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + log.Warningf("%s: failed to purge resource signature %s v%s: %s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber, err) + } + } else { + log.Tracef("%s: purged resource signature %s v%s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber) + } + + // Remove unpacked version of resource. + ext := filepath.Ext(storagePath) + if ext == "" { + // Nothing to do if file does not have an extension. + continue + } + unpackedPath := strings.TrimSuffix(storagePath, ext) + + // Remove if it exists, or an error occurs on access. + _, err = os.Stat(unpackedPath) + if err == nil || !errors.Is(err, fs.ErrNotExist) { + err = os.Remove(unpackedPath) + if err != nil { + log.Warningf("%s: failed to purge unpacked resource %s v%s: %s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber, err) + } else { + log.Tracef("%s: purged unpacked resource %s v%s", res.registry.Name, rv.resource.Identifier, rv.VersionNumber) + } + } + } + + // remove entries of deleted files + res.Versions = res.Versions[purgeBoundary:] +} + +// SigningMetadata returns the metadata to be included in signatures. +func (rv *ResourceVersion) SigningMetadata() map[string]string { + return map[string]string{ + "id": rv.resource.Identifier, + "version": rv.VersionNumber, + } +} + +// GetFile returns the version as a *File. +// It locks the resource for doing so. +func (rv *ResourceVersion) GetFile() *File { + rv.resource.Lock() + defer rv.resource.Unlock() + + // check for notifier + if rv.resource.notifier == nil { + // create new notifier + rv.resource.notifier = newNotifier() + } + + // create file + return &File{ + resource: rv.resource, + version: rv, + notifier: rv.resource.notifier, + versionedPath: rv.versionedPath(), + storagePath: rv.storagePath(), + } +} + +// versionedPath returns the versioned identifier. +func (rv *ResourceVersion) versionedPath() string { + return GetVersionedPath(rv.resource.Identifier, rv.VersionNumber) +} + +// versionedSigPath returns the versioned identifier of the file signature. +func (rv *ResourceVersion) versionedSigPath() string { + return GetVersionedPath(rv.resource.Identifier, rv.VersionNumber) + filesig.Extension +} + +// storagePath returns the absolute storage path. +func (rv *ResourceVersion) storagePath() string { + return filepath.Join(rv.resource.registry.storageDir.Path, filepath.FromSlash(rv.versionedPath())) +} + +// storageSigPath returns the absolute storage path of the file signature. 
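storagePath and versionedPath above bridge the identifier/version pair and the on-disk layout: the versioned identifier is produced by GetVersionedPath (from filename.go earlier in this patch) and then joined onto the registry's storage directory. For illustration:

package main

import (
	"fmt"

	"github.com/safing/portmaster/base/updater"
)

func main() {
	// Identifier plus version becomes the versioned path...
	fmt.Println(updater.GetVersionedPath("linux_amd64/core/portmaster-core", "0.8.13"))
	// Prints: linux_amd64/core/portmaster-core_v0-8-13

	// ...and the transformation is reversible for files found on disk.
	id, version, ok := updater.GetIdentifierAndVersion("all/ui/modules/portmaster_v0-2-4.zip")
	fmt.Println(id, version, ok)
	// Prints: all/ui/modules/portmaster.zip 0.2.4 true
}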
+func (rv *ResourceVersion) storageSigPath() string { + return rv.storagePath() + filesig.Extension +} diff --git a/base/updater/resource_test.go b/base/updater/resource_test.go new file mode 100644 index 000000000..ceb51e9f9 --- /dev/null +++ b/base/updater/resource_test.go @@ -0,0 +1,119 @@ +package updater + +import ( + "fmt" + "testing" + + semver "github.com/hashicorp/go-version" + "github.com/stretchr/testify/assert" +) + +func TestVersionSelection(t *testing.T) { + t.Parallel() + + res := registry.newResource("test/a") + + err := res.AddVersion("1.2.2", true, false, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("1.2.3", true, false, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("1.2.4-beta", true, false, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("1.2.4-staging", true, false, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("1.2.5", false, false, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("1.2.6-beta", false, false, false) + if err != nil { + t.Fatal(err) + } + err = res.AddVersion("0", true, false, false) + if err != nil { + t.Fatal(err) + } + + registry.UsePreReleases = true + registry.DevMode = true + registry.Online = true + res.Index = &Index{AutoDownload: true} + + res.selectVersion() + if res.SelectedVersion.VersionNumber != "0.0.0" { + t.Errorf("selected version should be 0.0.0, not %s", res.SelectedVersion.VersionNumber) + } + + registry.DevMode = false + res.selectVersion() + if res.SelectedVersion.VersionNumber != "1.2.6-beta" { + t.Errorf("selected version should be 1.2.6-beta, not %s", res.SelectedVersion.VersionNumber) + } + + registry.UsePreReleases = false + res.selectVersion() + if res.SelectedVersion.VersionNumber != "1.2.5" { + t.Errorf("selected version should be 1.2.5, not %s", res.SelectedVersion.VersionNumber) + } + + registry.Online = false + res.selectVersion() + if res.SelectedVersion.VersionNumber != "1.2.3" { + t.Errorf("selected version should be 1.2.3, not %s", res.SelectedVersion.VersionNumber) + } + + f123 := res.GetFile() + f123.markActiveWithLocking() + + err = res.Blacklist("1.2.3") + if err != nil { + t.Fatal(err) + } + if res.SelectedVersion.VersionNumber != "1.2.2" { + t.Errorf("selected version should be 1.2.2, not %s", res.SelectedVersion.VersionNumber) + } + + if !f123.UpgradeAvailable() { + t.Error("upgrade should be available (flag)") + } + select { + case <-f123.WaitForAvailableUpgrade(): + default: + t.Error("upgrade should be available (chan)") + } + + t.Logf("resource: %+v", res) + for _, rv := range res.Versions { + t.Logf("version %s: %+v", rv.VersionNumber, rv) + } +} + +func TestVersionParsing(t *testing.T) { + t.Parallel() + + assert.Equal(t, "1.2.3", parseVersion("1.2.3")) + assert.Equal(t, "1.2.0", parseVersion("1.2.0")) + assert.Equal(t, "0.2.0", parseVersion("0.2.0")) + assert.Equal(t, "0.0.0", parseVersion("0")) + assert.Equal(t, "1.2.3-b", parseVersion("1.2.3-b")) + assert.Equal(t, "1.2.3-b", parseVersion("1.2.3b")) + assert.Equal(t, "1.2.3-beta", parseVersion("1.2.3-beta")) + assert.Equal(t, "1.2.3-beta", parseVersion("1.2.3beta")) + assert.Equal(t, "1.2.3", parseVersion("01.02.03")) +} + +func parseVersion(v string) string { + sv, err := semver.NewVersion(v) + if err != nil { + return fmt.Sprintf("failed to parse version: %s", err) + } + return sv.String() +} diff --git a/base/updater/signing.go b/base/updater/signing.go new file mode 100644 index 000000000..cffd5cbed --- /dev/null +++ b/base/updater/signing.go @@ 
-0,0 +1,49 @@ +package updater + +import ( + "strings" + + "github.com/safing/jess" +) + +// VerificationOptions holds options for verification of files. +type VerificationOptions struct { + TrustStore jess.TrustStore + DownloadPolicy SignaturePolicy + DiskLoadPolicy SignaturePolicy +} + +// GetVerificationOptions returns the verification options for the given identifier. +func (reg *ResourceRegistry) GetVerificationOptions(identifier string) *VerificationOptions { + if reg.Verification == nil { + return nil + } + + var ( + longestPrefix = -1 + bestMatch *VerificationOptions + ) + for prefix, opts := range reg.Verification { + if len(prefix) > longestPrefix && strings.HasPrefix(identifier, prefix) { + longestPrefix = len(prefix) + bestMatch = opts + } + } + + return bestMatch +} + +// SignaturePolicy defines behavior in case of errors. +type SignaturePolicy uint8 + +// Signature Policies. +const ( + // SignaturePolicyRequire fails on any error. + SignaturePolicyRequire = iota + + // SignaturePolicyWarn only warns on errors. + SignaturePolicyWarn + + // SignaturePolicyDisable only downloads signatures, but does not verify them. + SignaturePolicyDisable +) diff --git a/base/updater/state.go b/base/updater/state.go new file mode 100644 index 000000000..20c27f465 --- /dev/null +++ b/base/updater/state.go @@ -0,0 +1,180 @@ +package updater + +import ( + "sort" + "sync" + "time" + + "github.com/safing/portmaster/base/utils" +) + +// Registry States. +const ( + StateReady = "ready" // Default idle state. + StateChecking = "checking" // Downloading indexes. + StateDownloading = "downloading" // Downloading updates. + StateFetching = "fetching" // Fetching a single file. +) + +// RegistryState describes the registry state. +type RegistryState struct { + sync.Mutex + reg *ResourceRegistry + + // ID holds the ID of the state the registry is currently in. + ID string + + // Details holds further information about the current state. + Details any + + // Updates holds generic information about the current status of pending + // and recently downloaded updates. + Updates UpdateState + + // operationLock locks the operation of any state changing operation. + // This is separate from the registry lock, which locks access to the + // registry struct. + operationLock sync.Mutex +} + +// StateDownloadingDetails holds details of the downloading state. +type StateDownloadingDetails struct { + // Resources holds the resource IDs that are being downloaded. + Resources []string + + // FinishedUpTo holds the index of Resources that is currently being + // downloaded. Previous resources have finished downloading. + FinishedUpTo int +} + +// UpdateState holds generic information about the current status of pending +// and recently downloaded updates. +type UpdateState struct { + // LastCheckAt holds the time of the last update check. + LastCheckAt *time.Time + // LastCheckError holds the error of the last check. + LastCheckError error + // PendingDownload holds the resources that are pending download. + PendingDownload []string + + // LastDownloadAt holds the time when resources were downloaded the last time. + LastDownloadAt *time.Time + // LastDownloadError holds the error of the last download. + LastDownloadError error + // LastDownload holds the resources that we downloaded the last time updates + // were downloaded. + LastDownload []string + + // LastSuccessAt holds the time of the last successful update (check). + LastSuccessAt *time.Time +} + +// GetState returns the current registry state. 
+// The returned data must not be modified. +func (reg *ResourceRegistry) GetState() RegistryState { + reg.state.Lock() + defer reg.state.Unlock() + + return RegistryState{ + ID: reg.state.ID, + Details: reg.state.Details, + Updates: reg.state.Updates, + } +} + +// StartOperation starts an operation. +func (s *RegistryState) StartOperation(id string) bool { + defer s.notify() + + s.operationLock.Lock() + + s.Lock() + defer s.Unlock() + + s.ID = id + return true +} + +// UpdateOperationDetails updates the details of an operation. +// The supplied struct should be a copy and must not be changed after calling +// this function. +func (s *RegistryState) UpdateOperationDetails(details any) { + defer s.notify() + + s.Lock() + defer s.Unlock() + + s.Details = details +} + +// EndOperation ends an operation. +func (s *RegistryState) EndOperation() { + defer s.notify() + defer s.operationLock.Unlock() + + s.Lock() + defer s.Unlock() + + s.ID = StateReady + s.Details = nil +} + +// ReportUpdateCheck reports an update check to the registry state. +func (s *RegistryState) ReportUpdateCheck(pendingDownload []string, failed error) { + defer s.notify() + + sort.Strings(pendingDownload) + + s.Lock() + defer s.Unlock() + + now := time.Now() + s.Updates.LastCheckAt = &now + s.Updates.LastCheckError = failed + s.Updates.PendingDownload = pendingDownload + + if failed == nil { + s.Updates.LastSuccessAt = &now + } +} + +// ReportDownloads reports downloaded updates to the registry state. +func (s *RegistryState) ReportDownloads(downloaded []string, failed error) { + defer s.notify() + + sort.Strings(downloaded) + + s.Lock() + defer s.Unlock() + + now := time.Now() + s.Updates.LastDownloadAt = &now + s.Updates.LastDownloadError = failed + s.Updates.LastDownload = downloaded + + // Remove downloaded resources from the pending list. + if len(s.Updates.PendingDownload) > 0 { + newPendingDownload := make([]string, 0, len(s.Updates.PendingDownload)) + for _, pending := range s.Updates.PendingDownload { + if !utils.StringInSlice(downloaded, pending) { + newPendingDownload = append(newPendingDownload, pending) + } + } + s.Updates.PendingDownload = newPendingDownload + } + + if failed == nil { + s.Updates.LastSuccessAt = &now + } +} + +func (s *RegistryState) notify() { + switch { + case s.reg == nil: + return + case s.reg.StateNotifyFunc == nil: + return + } + + s.reg.StateNotifyFunc(s) +} diff --git a/base/updater/storage.go b/base/updater/storage.go new file mode 100644 index 000000000..cd05bdbda --- /dev/null +++ b/base/updater/storage.go @@ -0,0 +1,272 @@ +package updater + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/safing/jess/filesig" + "github.com/safing/jess/lhash" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" +) + +// ScanStorage scans root within the storage dir and adds found +// resources to the registry. If an error occurred, it is logged +// and the last error is returned. Everything that was found +// despite errors is added to the registry anyway. Leave root +// empty to scan the full storage dir. 
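Since notify runs synchronously after every state change, the StateNotifyFunc documented above should only copy what it needs and return quickly. A minimal consumer sketch (attachStateLogger is illustrative, not part of this patch):

package example

import (
	"log"

	"github.com/safing/portmaster/base/updater"
)

// attachStateLogger logs every registry state transition.
func attachStateLogger(reg *updater.ResourceRegistry) {
	reg.StateNotifyFunc = func(s *updater.RegistryState) {
		// The callback may lock the state, but must not block.
		s.Lock()
		id := s.ID
		pending := len(s.Updates.PendingDownload)
		s.Unlock()
		log.Printf("updater state: %s (%d downloads pending)", id, pending)
	}
}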
+func (reg *ResourceRegistry) ScanStorage(root string) error { + var lastError error + + // prep root + if root == "" { + root = reg.storageDir.Path + } else { + var err error + root, err = filepath.Abs(root) + if err != nil { + return err + } + if !strings.HasPrefix(root, reg.storageDir.Path) { + return errors.New("supplied scan root path not within storage") + } + } + + // walk fs + _ = filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + // skip tmp dir (including errors trying to read it) + if strings.HasPrefix(path, reg.tmpDir.Path) { + return filepath.SkipDir + } + + // handle walker error + if err != nil { + lastError = fmt.Errorf("%s: could not read %s: %w", reg.Name, path, err) + log.Warning(lastError.Error()) + return nil + } + + // Ignore file signatures. + if strings.HasSuffix(path, filesig.Extension) { + return nil + } + + // get relative path to storage + relativePath, err := filepath.Rel(reg.storageDir.Path, path) + if err != nil { + lastError = fmt.Errorf("%s: could not get relative path of %s: %w", reg.Name, path, err) + log.Warning(lastError.Error()) + return nil + } + + // convert to identifier and version + relativePath = filepath.ToSlash(relativePath) + identifier, version, ok := GetIdentifierAndVersion(relativePath) + if !ok { + // file does not conform to format + return nil + } + + // fully ignore directories that also have an identifier - these will be unpacked resources + if info.IsDir() { + return filepath.SkipDir + } + + // save + err = reg.AddResource(identifier, version, nil, true, false, false) + if err != nil { + lastError = fmt.Errorf("%s: could not get add resource %s v%s: %w", reg.Name, identifier, version, err) + log.Warning(lastError.Error()) + } + return nil + }) + + return lastError +} + +// LoadIndexes loads the current release indexes from disk +// or will fetch a new version if not available and the +// registry is marked as online. +func (reg *ResourceRegistry) LoadIndexes(ctx context.Context) error { + var firstErr error + client := &http.Client{} + for _, idx := range reg.getIndexes() { + err := reg.loadIndexFile(idx) + if err == nil { + log.Debugf("%s: loaded index %s", reg.Name, idx.Path) + } else if reg.Online { + // try to download the index file if a local disk version + // does not exist or we don't have permission to read it. + if errors.Is(err, fs.ErrNotExist) || errors.Is(err, fs.ErrPermission) { + err = reg.downloadIndex(ctx, client, idx) + } + } + + if err != nil && firstErr == nil { + firstErr = err + } + } + + return firstErr +} + +// getIndexes returns a copy of the index. +// The indexes itself are references. +func (reg *ResourceRegistry) getIndexes() []*Index { + reg.RLock() + defer reg.RUnlock() + + indexes := make([]*Index, len(reg.indexes)) + copy(indexes, reg.indexes) + return indexes +} + +func (reg *ResourceRegistry) loadIndexFile(idx *Index) error { + indexPath := filepath.Join(reg.storageDir.Path, filepath.FromSlash(idx.Path)) + indexData, err := os.ReadFile(indexPath) + if err != nil { + return fmt.Errorf("failed to read index file %s: %w", idx.Path, err) + } + + // Verify signature, if enabled. + if verifOpts := reg.GetVerificationOptions(idx.Path); verifOpts != nil { + // Load and check signature. 
+ verifiedHash, _, err := reg.loadAndVerifySigFile(verifOpts, indexPath+filesig.Extension) + if err != nil { + switch verifOpts.DiskLoadPolicy { + case SignaturePolicyRequire: + return fmt.Errorf("failed to verify signature of index %s: %w", idx.Path, err) + case SignaturePolicyWarn: + log.Warningf("%s: failed to verify signature of index %s: %s", reg.Name, idx.Path, err) + case SignaturePolicyDisable: + log.Debugf("%s: failed to verify signature of index %s: %s", reg.Name, idx.Path, err) + } + } + + // Check if signature checksum matches the index data. + if err == nil && !verifiedHash.Matches(indexData) { + switch verifOpts.DiskLoadPolicy { + case SignaturePolicyRequire: + return fmt.Errorf("index file %s does not match signature", idx.Path) + case SignaturePolicyWarn: + log.Warningf("%s: index file %s does not match signature", reg.Name, idx.Path) + case SignaturePolicyDisable: + log.Debugf("%s: index file %s does not match signature", reg.Name, idx.Path) + } + } + } + + // Parse the index file. + indexFile, err := ParseIndexFile(indexData, idx.Channel, idx.LastRelease) + if err != nil { + return fmt.Errorf("failed to parse index file %s: %w", idx.Path, err) + } + + // Update last seen release. + idx.LastRelease = indexFile.Published + + // Warn if there aren't any releases in the index. + if len(indexFile.Releases) == 0 { + log.Debugf("%s: index %s has no releases", reg.Name, idx.Path) + return nil + } + + // Add index releases to available resources. + err = reg.AddResources(indexFile.Releases, idx, false, true, idx.PreRelease) + if err != nil { + log.Warningf("%s: failed to add resource: %s", reg.Name, err) + } + return nil +} + +func (reg *ResourceRegistry) loadAndVerifySigFile(verifOpts *VerificationOptions, sigFilePath string) (*lhash.LabeledHash, []byte, error) { + // Load signature file. + sigFileData, err := os.ReadFile(sigFilePath) + if err != nil { + return nil, nil, fmt.Errorf("failed to read signature file: %w", err) + } + + // Extract all signatures. + sigs, err := filesig.ParseSigFile(sigFileData) + switch { + case len(sigs) == 0 && err != nil: + return nil, nil, fmt.Errorf("failed to parse signature file: %w", err) + case len(sigs) == 0: + return nil, nil, errors.New("no signatures found in signature file") + case err != nil: + return nil, nil, fmt.Errorf("failed to parse signature file: %w", err) + } + + // Verify all signatures. + var verifiedHash *lhash.LabeledHash + for _, sig := range sigs { + fd, err := filesig.VerifyFileData( + sig, + nil, + verifOpts.TrustStore, + ) + if err != nil { + return nil, sigFileData, err + } + + // Save or check verified hash. + if verifiedHash == nil { + verifiedHash = fd.FileHash() + } else if !fd.FileHash().Equal(verifiedHash) { + // Return an error if two valid hashes mismatch. + // For simplicity, all hash algorithms must be the same for now. + return nil, sigFileData, errors.New("file hashes from different signatures do not match") + } + } + + return verifiedHash, sigFileData, nil +} + +// CreateSymlinks creates a directory structure with unversioned symlinks to the given updates list. 
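On the disk-load path above, a failed verification is not automatically fatal; it is downgraded according to the configured DiskLoadPolicy. Distilled into a standalone helper for clarity (the helper is ours, not part of the patch), the pattern is:

// handleVerifyErrorSketch is illustrative only and not part of this patch.
// It mirrors the policy switch in loadIndexFile: the same error is fatal,
// a warning, or just a debug message, depending on the policy.
func handleVerifyErrorSketch(name string, policy SignaturePolicy, err error) error {
	switch policy {
	case SignaturePolicyRequire:
		return err // Abort loading.
	case SignaturePolicyWarn:
		log.Warningf("%s: verification failed: %s", name, err)
	case SignaturePolicyDisable:
		log.Debugf("%s: verification failed: %s", name, err)
	}
	return nil // Continue despite the failed verification.
}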
+func (reg *ResourceRegistry) CreateSymlinks(symlinkRoot *utils.DirStructure) error { + err := os.RemoveAll(symlinkRoot.Path) + if err != nil { + return fmt.Errorf("failed to wipe symlink root: %w", err) + } + + err = symlinkRoot.Ensure() + if err != nil { + return fmt.Errorf("failed to create symlink root: %w", err) + } + + reg.RLock() + defer reg.RUnlock() + + for _, res := range reg.resources { + if res.SelectedVersion == nil { + return fmt.Errorf("no selected version available for %s", res.Identifier) + } + + targetPath := res.SelectedVersion.storagePath() + linkPath := filepath.Join(symlinkRoot.Path, filepath.FromSlash(res.Identifier)) + linkPathDir := filepath.Dir(linkPath) + + err = symlinkRoot.EnsureAbsPath(linkPathDir) + if err != nil { + return fmt.Errorf("failed to create dir for link: %w", err) + } + + relativeTargetPath, err := filepath.Rel(linkPathDir, targetPath) + if err != nil { + return fmt.Errorf("failed to get relative target path: %w", err) + } + + err = os.Symlink(relativeTargetPath, linkPath) + if err != nil { + return fmt.Errorf("failed to link %s: %w", res.Identifier, err) + } + } + + return nil +} diff --git a/base/updater/storage_test.go b/base/updater/storage_test.go new file mode 100644 index 000000000..2e4122fa3 --- /dev/null +++ b/base/updater/storage_test.go @@ -0,0 +1,68 @@ +package updater + +/* +func testLoadLatestScope(t *testing.T, basePath, filePath, expectedIdentifier, expectedVersion string) { + fullPath := filepath.Join(basePath, filePath) + + // create dir + dirPath := filepath.Dir(fullPath) + err := os.MkdirAll(dirPath, 0755) + if err != nil { + t.Fatalf("could not create test dir: %s\n", err) + return + } + + // touch file + err = os.WriteFile(fullPath, []byte{}, 0644) + if err != nil { + t.Fatalf("could not create test file: %s\n", err) + return + } + + // run loadLatestScope + latest, err := ScanForLatest(basePath, true) + if err != nil { + t.Errorf("could not update latest: %s\n", err) + return + } + for key, val := range latest { + localUpdates[key] = val + } + + // test result + version, ok := localUpdates[expectedIdentifier] + if !ok { + t.Errorf("identifier %s not in map", expectedIdentifier) + t.Errorf("current map: %v", localUpdates) + } + if version != expectedVersion { + t.Errorf("unexpected version for %s: %s", filePath, version) + } +} + +func TestLoadLatestScope(t *testing.T) { + + updatesLock.Lock() + defer updatesLock.Unlock() + + tmpDir, err := os.MkdirTemp("", "testing_") + if err != nil { + t.Fatalf("could not create test dir: %s\n", err) + return + } + defer os.RemoveAll(tmpDir) + + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3.zip", "all/ui/assets.zip", "1.2.3") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-4b.zip", "all/ui/assets.zip", "1.2.4b") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-5.zip", "all/ui/assets.zip", "1.2.5") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-3-4.zip", "all/ui/assets.zip", "1.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v2-3-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-3.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-2-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "all/ui/assets_v1-3-4.zip", "all/ui/assets.zip", "2.3.4") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v1-2-3", "os_platform/portmaster/portmaster", "1.2.3") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v2-1-1", 
"os_platform/portmaster/portmaster", "2.1.1") + testLoadLatestScope(t, tmpDir, "os_platform/portmaster/portmaster_v1-2-3", "os_platform/portmaster/portmaster", "2.1.1") + +} +*/ diff --git a/base/updater/unpacking.go b/base/updater/unpacking.go new file mode 100644 index 000000000..75d489212 --- /dev/null +++ b/base/updater/unpacking.go @@ -0,0 +1,195 @@ +package updater + +import ( + "archive/zip" + "compress/gzip" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path" + "path/filepath" + "strings" + + "github.com/hashicorp/go-multierror" + + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" +) + +// MaxUnpackSize specifies the maximum size that will be unpacked. +const MaxUnpackSize = 1000000000 // 1GB + +// UnpackGZIP unpacks a GZIP compressed reader r +// and returns a new reader. It's suitable to be +// used with registry.GetPackedFile. +func UnpackGZIP(r io.Reader) (io.Reader, error) { + return gzip.NewReader(r) +} + +// UnpackResources unpacks all resources defined in the AutoUnpack list. +func (reg *ResourceRegistry) UnpackResources() error { + reg.RLock() + defer reg.RUnlock() + + var multierr *multierror.Error + for _, res := range reg.resources { + if utils.StringInSlice(reg.AutoUnpack, res.Identifier) { + err := res.UnpackArchive() + if err != nil { + multierr = multierror.Append( + multierr, + fmt.Errorf("%s: %w", res.Identifier, err), + ) + } + } + } + + return multierr.ErrorOrNil() +} + +const ( + zipSuffix = ".zip" +) + +// UnpackArchive unpacks the archive the resource refers to. The contents are +// unpacked into a directory with the same name as the file, excluding the +// suffix. If the destination folder already exists, it is assumed that the +// contents have already been correctly unpacked. +func (res *Resource) UnpackArchive() error { + res.Lock() + defer res.Unlock() + + // Only unpack selected versions. + if res.SelectedVersion == nil { + return nil + } + + switch { + case strings.HasSuffix(res.Identifier, zipSuffix): + return res.unpackZipArchive() + default: + return fmt.Errorf("unsupported file type for unpacking") + } +} + +func (res *Resource) unpackZipArchive() error { + // Get file and directory paths. + archiveFile := res.SelectedVersion.storagePath() + destDir := strings.TrimSuffix(archiveFile, zipSuffix) + tmpDir := filepath.Join( + res.registry.tmpDir.Path, + filepath.FromSlash(strings.TrimSuffix( + path.Base(res.SelectedVersion.versionedPath()), + zipSuffix, + )), + ) + + // Check status of destination. + dstStat, err := os.Stat(destDir) + switch { + case errors.Is(err, fs.ErrNotExist): + // The destination does not exist, continue with unpacking. + case err != nil: + return fmt.Errorf("cannot access destination for unpacking: %w", err) + case !dstStat.IsDir(): + return fmt.Errorf("destination for unpacking is blocked by file: %s", dstStat.Name()) + default: + // Archive already seems to be unpacked. + return nil + } + + // Create the tmp directory for unpacking. + err = res.registry.tmpDir.EnsureAbsPath(tmpDir) + if err != nil { + return fmt.Errorf("failed to create tmp dir for unpacking: %w", err) + } + + // Defer clean up of directories. + defer func() { + // Always clean up the tmp dir. + _ = os.RemoveAll(tmpDir) + // Cleanup the destination in case of an error. + if err != nil { + _ = os.RemoveAll(destDir) + } + }() + + // Open the archive for reading. 
+ var archiveReader *zip.ReadCloser + archiveReader, err = zip.OpenReader(archiveFile) + if err != nil { + return fmt.Errorf("failed to open zip reader: %w", err) + } + defer func() { + _ = archiveReader.Close() + }() + + // Save all files to the tmp dir. + for _, file := range archiveReader.File { + err = copyFromZipArchive( + file, + filepath.Join(tmpDir, filepath.FromSlash(file.Name)), + ) + if err != nil { + return fmt.Errorf("failed to extract archive file %s: %w", file.Name, err) + } + } + + // Make the final move. + err = os.Rename(tmpDir, destDir) + if err != nil { + return fmt.Errorf("failed to move the extracted archive from %s to %s: %w", tmpDir, destDir, err) + } + + // Fix permissions on the destination dir. + err = res.registry.storageDir.EnsureAbsPath(destDir) + if err != nil { + return fmt.Errorf("failed to apply directory permissions on %s: %w", destDir, err) + } + + log.Infof("%s: unpacked %s", res.registry.Name, res.SelectedVersion.versionedPath()) + return nil +} + +func copyFromZipArchive(archiveFile *zip.File, dstPath string) error { + // If file is a directory, create it and continue. + if archiveFile.FileInfo().IsDir() { + err := os.Mkdir(dstPath, archiveFile.Mode()) + if err != nil { + return fmt.Errorf("failed to create directory %s: %w", dstPath, err) + } + return nil + } + + // Open archived file for reading. + fileReader, err := archiveFile.Open() + if err != nil { + return fmt.Errorf("failed to open file in archive: %w", err) + } + defer func() { + _ = fileReader.Close() + }() + + // Open destination file for writing. + dstFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, archiveFile.Mode()) + if err != nil { + return fmt.Errorf("failed to open destination file %s: %w", dstPath, err) + } + defer func() { + _ = dstFile.Close() + }() + + // Copy full file from archive to dst. + if _, err := io.CopyN(dstFile, fileReader, MaxUnpackSize); err != nil { + // EOF is expected here as the archive is likely smaller + // thane MaxUnpackSize + if errors.Is(err, io.EOF) { + return nil + } + return err + } + + return nil +} diff --git a/base/updater/updating.go b/base/updater/updating.go new file mode 100644 index 000000000..23e3df455 --- /dev/null +++ b/base/updater/updating.go @@ -0,0 +1,359 @@ +package updater + +import ( + "context" + "fmt" + "net/http" + "os" + "path" + "path/filepath" + "strings" + + "golang.org/x/exp/slices" + + "github.com/safing/jess/filesig" + "github.com/safing/jess/lhash" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" +) + +// UpdateIndexes downloads all indexes. An error is only returned when all +// indexes fail to update. +func (reg *ResourceRegistry) UpdateIndexes(ctx context.Context) error { + var lastErr error + var anySuccess bool + + // Start registry operation. + reg.state.StartOperation(StateChecking) + defer reg.state.EndOperation() + + client := &http.Client{} + for _, idx := range reg.getIndexes() { + if err := reg.downloadIndex(ctx, client, idx); err != nil { + lastErr = err + log.Warningf("%s: failed to update index %s: %s", reg.Name, idx.Path, err) + } else { + anySuccess = true + } + } + + // If all indexes failed to update, fail. + if !anySuccess { + err := fmt.Errorf("failed to update all indexes, last error was: %w", lastErr) + reg.state.ReportUpdateCheck(nil, err) + return err + } + + // Get pending resources and update status. 
+ pendingResourceVersions, _ := reg.GetPendingDownloads(true, false) + reg.state.ReportUpdateCheck( + humanInfoFromResourceVersions(pendingResourceVersions), + nil, + ) + + return nil +} + +func (reg *ResourceRegistry) downloadIndex(ctx context.Context, client *http.Client, idx *Index) error { + var ( + // Index. + indexErr error + indexData []byte + downloadURL string + + // Signature. + sigErr error + verifiedHash *lhash.LabeledHash + sigFileData []byte + verifOpts = reg.GetVerificationOptions(idx.Path) + ) + + // Upgrade to v2 index if verification is enabled. + downloadIndexPath := idx.Path + if verifOpts != nil { + downloadIndexPath = strings.TrimSuffix(downloadIndexPath, baseIndexExtension) + v2IndexExtension + } + + // Download new index and signature. + for tries := 0; tries < 3; tries++ { + // Index and signature need to be fetched together, so that they are + // fetched from the same source. One source should always have a matching + // index and signature. Backup sources may be behind a little. + // If the signature verification fails, another source should be tried. + + // Get index data. + indexData, downloadURL, indexErr = reg.fetchData(ctx, client, downloadIndexPath, tries) + if indexErr != nil { + log.Debugf("%s: failed to fetch index %s: %s", reg.Name, downloadURL, indexErr) + continue + } + + // Get signature and verify it. + if verifOpts != nil { + verifiedHash, sigFileData, sigErr = reg.fetchAndVerifySigFile( + ctx, client, + verifOpts, downloadIndexPath+filesig.Extension, nil, + tries, + ) + if sigErr != nil { + log.Debugf("%s: failed to verify signature of %s: %s", reg.Name, downloadURL, sigErr) + continue + } + + // Check if the index matches the verified hash. + if verifiedHash.Matches(indexData) { + log.Infof("%s: verified signature of %s", reg.Name, downloadURL) + } else { + sigErr = ErrIndexChecksumMismatch + log.Debugf("%s: checksum does not match file from %s", reg.Name, downloadURL) + continue + } + } + + break + } + if indexErr != nil { + return fmt.Errorf("failed to fetch index %s: %w", downloadIndexPath, indexErr) + } + if sigErr != nil { + return fmt.Errorf("failed to fetch or verify index %s signature: %w", downloadIndexPath, sigErr) + } + + // Parse the index file. + indexFile, err := ParseIndexFile(indexData, idx.Channel, idx.LastRelease) + if err != nil { + return fmt.Errorf("failed to parse index %s: %w", idx.Path, err) + } + + // Add index data to registry. + if len(indexFile.Releases) > 0 { + // Check if all resources are within the indexes' authority. + authoritativePath := path.Dir(idx.Path) + "/" + if authoritativePath == "./" { + // Fix path for indexes at the storage root. + authoritativePath = "" + } + cleanedData := make(map[string]string, len(indexFile.Releases)) + for key, version := range indexFile.Releases { + if strings.HasPrefix(key, authoritativePath) { + cleanedData[key] = version + } else { + log.Warningf("%s: index %s oversteps it's authority by defining version for %s", reg.Name, idx.Path, key) + } + } + + // add resources to registry + err = reg.AddResources(cleanedData, idx, false, true, idx.PreRelease) + if err != nil { + log.Warningf("%s: failed to add resources: %s", reg.Name, err) + } + } else { + log.Debugf("%s: index %s is empty", reg.Name, idx.Path) + } + + // Check if dest dir exists. 
+ indexDir := filepath.FromSlash(path.Dir(idx.Path)) + err = reg.storageDir.EnsureRelPath(indexDir) + if err != nil { + log.Warningf("%s: failed to ensure directory for updated index %s: %s", reg.Name, idx.Path, err) + } + + // Index files must be readable by portmaster-staert with user permissions in order to load the index. + err = os.WriteFile( //nolint:gosec + filepath.Join(reg.storageDir.Path, filepath.FromSlash(idx.Path)), + indexData, 0o0644, + ) + if err != nil { + log.Warningf("%s: failed to save updated index %s: %s", reg.Name, idx.Path, err) + } + + // Write signature file, if we have one. + if len(sigFileData) > 0 { + err = os.WriteFile( //nolint:gosec + filepath.Join(reg.storageDir.Path, filepath.FromSlash(idx.Path)+filesig.Extension), + sigFileData, 0o0644, + ) + if err != nil { + log.Warningf("%s: failed to save updated index signature %s: %s", reg.Name, idx.Path+filesig.Extension, err) + } + } + + log.Infof("%s: updated index %s with %d entries", reg.Name, idx.Path, len(indexFile.Releases)) + return nil +} + +// DownloadUpdates checks if updates are available and downloads updates of used components. +func (reg *ResourceRegistry) DownloadUpdates(ctx context.Context, includeManual bool) error { + // Start registry operation. + reg.state.StartOperation(StateDownloading) + defer reg.state.EndOperation() + + // Get pending updates. + toUpdate, missingSigs := reg.GetPendingDownloads(includeManual, true) + downloadDetailsResources := humanInfoFromResourceVersions(toUpdate) + reg.state.UpdateOperationDetails(&StateDownloadingDetails{ + Resources: downloadDetailsResources, + }) + + // nothing to update + if len(toUpdate) == 0 && len(missingSigs) == 0 { + log.Infof("%s: everything up to date", reg.Name) + return nil + } + + // check download dir + if err := reg.tmpDir.Ensure(); err != nil { + return fmt.Errorf("could not prepare tmp directory for download: %w", err) + } + + // download updates + log.Infof("%s: starting to download %d updates", reg.Name, len(toUpdate)) + client := &http.Client{} + var reportError error + + for i, rv := range toUpdate { + log.Infof( + "%s: downloading update [%d/%d]: %s version %s", + reg.Name, + i+1, len(toUpdate), + rv.resource.Identifier, rv.VersionNumber, + ) + var err error + for tries := 0; tries < 3; tries++ { + err = reg.fetchFile(ctx, client, rv, tries) + if err == nil { + // Update resource version state. + rv.resource.Lock() + rv.Available = true + if rv.resource.VerificationOptions != nil { + rv.SigAvailable = true + } + rv.resource.Unlock() + + break + } + } + if err != nil { + reportError := fmt.Errorf("failed to download %s version %s: %w", rv.resource.Identifier, rv.VersionNumber, err) + log.Warningf("%s: %s", reg.Name, reportError) + } + + reg.state.UpdateOperationDetails(&StateDownloadingDetails{ + Resources: downloadDetailsResources, + FinishedUpTo: i + 1, + }) + } + + if len(missingSigs) > 0 { + log.Infof("%s: downloading %d missing signatures", reg.Name, len(missingSigs)) + + for _, rv := range missingSigs { + var err error + for tries := 0; tries < 3; tries++ { + err = reg.fetchMissingSig(ctx, client, rv, tries) + if err == nil { + // Update resource version state. 
+ rv.resource.Lock() + rv.SigAvailable = true + rv.resource.Unlock() + + break + } + } + if err != nil { + reportError := fmt.Errorf("failed to download missing sig of %s version %s: %w", rv.resource.Identifier, rv.VersionNumber, err) + log.Warningf("%s: %s", reg.Name, reportError) + } + } + } + + reg.state.ReportDownloads( + downloadDetailsResources, + reportError, + ) + log.Infof("%s: finished downloading updates", reg.Name) + + return nil +} + +// DownloadUpdates checks if updates are available and downloads updates of used components. + +// GetPendingDownloads returns the list of pending downloads. +// If manual is set, indexes with AutoDownload=false will be checked. +// If auto is set, indexes with AutoDownload=true will be checked. +func (reg *ResourceRegistry) GetPendingDownloads(manual, auto bool) (resources, sigs []*ResourceVersion) { + reg.RLock() + defer reg.RUnlock() + + // create list of downloads + var toUpdate []*ResourceVersion + var missingSigs []*ResourceVersion + + for _, res := range reg.resources { + func() { + res.Lock() + defer res.Unlock() + + // Skip resources without index or indexes that should not be reported + // according to parameters. + switch { + case res.Index == nil: + // Cannot download if resource is not part of an index. + return + case manual && !res.Index.AutoDownload: + // Manual update report and index is not auto-download. + case auto && res.Index.AutoDownload: + // Auto update report and index is auto-download. + default: + // Resource should not be reported. + return + } + + // Skip resources we don't need. + switch { + case res.inUse(): + // Update if resource is in use. + case res.available(): + // Update if resource is available locally, ie. was used in the past. + case utils.StringInSlice(reg.MandatoryUpdates, res.Identifier): + // Update is set as mandatory. + default: + // Resource does not need to be updated. + return + } + + // Go through all versions until we find versions that need updating. + for _, rv := range res.Versions { + switch { + case !rv.CurrentRelease: + // We are not interested in older releases. + case !rv.Available: + // File not available locally, download! + toUpdate = append(toUpdate, rv) + case !rv.SigAvailable && res.VerificationOptions != nil: + // File signature is not available and verification is enabled, download signature! + missingSigs = append(missingSigs, rv) + } + } + }() + } + + slices.SortFunc[[]*ResourceVersion, *ResourceVersion](toUpdate, func(a, b *ResourceVersion) int { + return strings.Compare(a.resource.Identifier, b.resource.Identifier) + }) + slices.SortFunc[[]*ResourceVersion, *ResourceVersion](missingSigs, func(a, b *ResourceVersion) int { + return strings.Compare(a.resource.Identifier, b.resource.Identifier) + }) + + return toUpdate, missingSigs +} + +func humanInfoFromResourceVersions(resourceVersions []*ResourceVersion) []string { + identifiers := make([]string, len(resourceVersions)) + + for i, rv := range resourceVersions { + identifiers[i] = fmt.Sprintf("%s v%s", rv.resource.Identifier, rv.VersionNumber) + } + + return identifiers +} diff --git a/base/utils/atomic.go b/base/utils/atomic.go new file mode 100644 index 000000000..8f77f1d8e --- /dev/null +++ b/base/utils/atomic.go @@ -0,0 +1,105 @@ +package utils + +import ( + "errors" + "fmt" + "io" + "io/fs" + "os" + + "github.com/safing/portmaster/base/utils/renameio" +) + +// AtomicFileOptions holds additional options for manipulating +// the behavior of CreateAtomic and friends. 
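UpdateIndexes and DownloadUpdates are designed to run in sequence: the first refreshes the release indexes and records what is pending, the second fetches those files plus any missing signatures. A rough sketch of a periodic update pass built on top of them; the wrapper is ours and error handling is kept minimal:

// updateCycleSketch is illustrative only and not part of this patch.
func updateCycleSketch(ctx context.Context, reg *ResourceRegistry) error {
	// Refresh all indexes; this also records pending downloads in the
	// registry state.
	if err := reg.UpdateIndexes(ctx); err != nil {
		return err
	}
	// Fetch updates for resources that are in use, available locally,
	// or listed as mandatory. Passing false considers only indexes
	// marked for auto-download.
	return reg.DownloadUpdates(ctx, false)
}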
+type AtomicFileOptions struct { + // Mode is the file mode for the new file. If + // 0, the file mode will be set to 0600. + Mode os.FileMode + + // TempDir is the path to the temp-directory + // that should be used. If empty, it defaults + // to the system temp. + TempDir string +} + +// CreateAtomic creates or overwrites a file at dest atomically using +// data from r. Atomic means that even in case of a power outage, +// dest will never be a zero-length file. It will always either contain +// the previous data (or not exist) or the new data but never anything +// in between. +func CreateAtomic(dest string, r io.Reader, opts *AtomicFileOptions) error { + if opts == nil { + opts = &AtomicFileOptions{} + } + + tmpFile, err := renameio.TempFile(opts.TempDir, dest) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer tmpFile.Cleanup() //nolint:errcheck + + if opts.Mode != 0 { + if err := tmpFile.Chmod(opts.Mode); err != nil { + return fmt.Errorf("failed to update mode bits of temp file: %w", err) + } + } + + if _, err := io.Copy(tmpFile, r); err != nil { + return fmt.Errorf("failed to copy source file: %w", err) + } + + if err := tmpFile.CloseAtomicallyReplace(); err != nil { + return fmt.Errorf("failed to rename temp file to %q", dest) + } + + return nil +} + +// CopyFileAtomic is like CreateAtomic but copies content from +// src to dest. If opts.Mode is 0 CopyFileAtomic tries to set +// the file mode of src to dest. +func CopyFileAtomic(dest string, src string, opts *AtomicFileOptions) error { + if opts == nil { + opts = &AtomicFileOptions{} + } + + if opts.Mode == 0 { + stat, err := os.Stat(src) + if err != nil { + return err + } + opts.Mode = stat.Mode() + } + + f, err := os.Open(src) + if err != nil { + return err + } + defer func() { + _ = f.Close() + }() + + return CreateAtomic(dest, f, opts) +} + +// ReplaceFileAtomic replaces the file at dest with the content from src. +// If dest exists it's file mode copied and used for the replacement. If +// not, dest will get the same file mode as src. See CopyFileAtomic and +// CreateAtomic for more information. +func ReplaceFileAtomic(dest string, src string, opts *AtomicFileOptions) error { + if opts == nil { + opts = &AtomicFileOptions{} + } + + if opts.Mode == 0 { + stat, err := os.Stat(dest) + if err == nil { + opts.Mode = stat.Mode() + } else if !errors.Is(err, fs.ErrNotExist) { + return err + } + } + + return CopyFileAtomic(dest, src, opts) +} diff --git a/base/utils/broadcastflag.go b/base/utils/broadcastflag.go new file mode 100644 index 000000000..ea6c7a487 --- /dev/null +++ b/base/utils/broadcastflag.go @@ -0,0 +1,84 @@ +package utils + +import ( + "sync" + + "github.com/tevino/abool" +) + +// BroadcastFlag is a simple system to broadcast a flag value. +type BroadcastFlag struct { + flag *abool.AtomicBool + signal chan struct{} + lock sync.Mutex +} + +// Flag receives changes from its broadcasting flag. +// A Flag must only be used in one goroutine and is not concurrency safe, +// but fast. +type Flag struct { + flag *abool.AtomicBool + signal chan struct{} + broadcaster *BroadcastFlag +} + +// NewBroadcastFlag returns a new BroadcastFlag. +// In the initial state, the flag is not set and the signal does not trigger. +func NewBroadcastFlag() *BroadcastFlag { + return &BroadcastFlag{ + flag: abool.New(), + signal: make(chan struct{}), + lock: sync.Mutex{}, + } +} + +// NewFlag returns a new Flag that listens to this broadcasting flag. 
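Seen from a consumer, the atomic file helpers above might be used like this; the path, mode, and payload are examples only:

// Hypothetical usage of CreateAtomic from another package; the path and
// file mode are example values.
package main

import (
	"strings"

	"github.com/safing/portmaster/base/utils"
)

func main() {
	payload := strings.NewReader(`{"example": true}`)

	// Either the previous file or the complete new one ends up on disk,
	// never a half-written version.
	err := utils.CreateAtomic("/tmp/example-config.json", payload, &utils.AtomicFileOptions{
		Mode: 0o600,
	})
	if err != nil {
		panic(err)
	}
}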
+// In the initial state, the flag is set and the signal triggers. +// You can call Refresh immediately to get the current state from the +// broadcasting flag. +func (bf *BroadcastFlag) NewFlag() *Flag { + newFlag := &Flag{ + flag: abool.NewBool(true), + signal: make(chan struct{}), + broadcaster: bf, + } + close(newFlag.signal) + return newFlag +} + +// NotifyAndReset notifies all flags of this broadcasting flag and resets the +// internal broadcast flag state. +func (bf *BroadcastFlag) NotifyAndReset() { + bf.lock.Lock() + defer bf.lock.Unlock() + + // Notify all flags of the change. + bf.flag.Set() + close(bf.signal) + + // Reset + bf.flag = abool.New() + bf.signal = make(chan struct{}) +} + +// Signal returns a channel that waits for the flag to be set. This does not +// reset the Flag itself, you'll need to call Refresh for that. +func (f *Flag) Signal() <-chan struct{} { + return f.signal +} + +// IsSet returns whether the flag was set since the last Refresh. +// This does not reset the Flag itself, you'll need to call Refresh for that. +func (f *Flag) IsSet() bool { + return f.flag.IsSet() +} + +// Refresh fetches the current state from the broadcasting flag. +func (f *Flag) Refresh() { + f.broadcaster.lock.Lock() + defer f.broadcaster.lock.Unlock() + + // Copy current flag and signal from the broadcasting flag. + f.flag = f.broadcaster.flag + f.signal = f.broadcaster.signal +} diff --git a/base/utils/call_limiter.go b/base/utils/call_limiter.go new file mode 100644 index 000000000..eb669a2f5 --- /dev/null +++ b/base/utils/call_limiter.go @@ -0,0 +1,87 @@ +package utils + +import ( + "sync" + "sync/atomic" + "time" +) + +// CallLimiter bundles concurrent calls and optionally limits how fast a function is called. +type CallLimiter struct { + pause time.Duration + + inLock sync.Mutex + lastExec time.Time + + waiters atomic.Int32 + outLock sync.Mutex +} + +// NewCallLimiter returns a new call limiter. +// Set minPause to zero to disable the minimum pause between calls. +func NewCallLimiter(minPause time.Duration) *CallLimiter { + return &CallLimiter{ + pause: minPause, + } +} + +// Do executes the given function. +// All concurrent calls to Do are bundled and return when f() finishes. +// Waits until the minimum pause is over before executing f() again. +func (l *CallLimiter) Do(f func()) { + // Wait for the previous waiters to exit. + l.inLock.Lock() + + // Defer final unlock to safeguard from panics. + defer func() { + // Execution is finished - leave. + // If we are the last waiter, let the next batch in. + if l.waiters.Add(-1) == 0 { + l.inLock.Unlock() + } + }() + + // Check if we are the first waiter. + if l.waiters.Add(1) == 1 { + // Take the lead on this execution run. + l.lead(f) + } else { + // We are not the first waiter, let others in. + l.inLock.Unlock() + } + + // Wait for execution to complete. + l.outLock.Lock() + l.outLock.Unlock() //nolint:staticcheck + + // Last statement is in defer above. +} + +func (l *CallLimiter) lead(f func()) { + // Make all others wait while we execute the function. + l.outLock.Lock() + + // Unlock in lock until execution is finished. + l.inLock.Unlock() + + // Transition from out lock to in lock when done. + defer func() { + // Update last execution time. + l.lastExec = time.Now().UTC() + // Stop newcomers from waiting on previous execution. + l.inLock.Lock() + // Allow waiters to leave. + l.outLock.Unlock() + }() + + // Wait for the minimum duration between executions. 
+ if l.pause > 0 { + sinceLastExec := time.Since(l.lastExec) + if sinceLastExec < l.pause { + time.Sleep(l.pause - sinceLastExec) + } + } + + // Execute. + f() +} diff --git a/base/utils/call_limiter_test.go b/base/utils/call_limiter_test.go new file mode 100644 index 000000000..16bd1d5ab --- /dev/null +++ b/base/utils/call_limiter_test.go @@ -0,0 +1,91 @@ +package utils + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/tevino/abool" +) + +func TestCallLimiter(t *testing.T) { + t.Parallel() + + pause := 10 * time.Millisecond + oa := NewCallLimiter(pause) + executed := abool.New() + var testWg sync.WaitGroup + + // One execution should gobble up the whole batch. + // We are doing this without sleep in function, so dummy exec first to trigger first pause. + oa.Do(func() {}) + // Start + for i := 0; i < 10; i++ { + testWg.Add(100) + for i := 0; i < 100; i++ { + go func() { + oa.Do(func() { + if !executed.SetToIf(false, true) { + t.Errorf("concurrent execution!") + } + }) + testWg.Done() + }() + } + testWg.Wait() + // Check if function was executed at least once. + if executed.IsNotSet() { + t.Errorf("no execution!") + } + executed.UnSet() // reset check + } + + // Wait for pause to reset. + time.Sleep(pause) + + // Continuous use with re-execution. + // Choose values so that about 10 executions are expected + var execs uint32 + testWg.Add(200) + for i := 0; i < 200; i++ { + go func() { + oa.Do(func() { + atomic.AddUint32(&execs, 1) + time.Sleep(10 * time.Millisecond) + }) + testWg.Done() + }() + + // Start one goroutine every 1ms. + time.Sleep(1 * time.Millisecond) + } + + testWg.Wait() + if execs <= 5 { + t.Errorf("unexpected low exec count: %d", execs) + } + if execs >= 15 { + t.Errorf("unexpected high exec count: %d", execs) + } + + // Wait for pause to reset. + time.Sleep(pause) + + // Check if the limiter correctly handles panics. + testWg.Add(100) + for i := 0; i < 100; i++ { + go func() { + defer func() { + _ = recover() + testWg.Done() + }() + oa.Do(func() { + time.Sleep(1 * time.Millisecond) + panic("test") + }) + }() + time.Sleep(100 * time.Microsecond) + } + testWg.Wait() +} diff --git a/base/utils/debug/debug.go b/base/utils/debug/debug.go new file mode 100644 index 000000000..06ac7b937 --- /dev/null +++ b/base/utils/debug/debug.go @@ -0,0 +1,148 @@ +package debug + +import ( + "bytes" + "fmt" + "runtime/pprof" + "strings" + "time" + + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" +) + +// Info gathers debugging information and stores everything in a buffer in +// order to write it to somewhere later. It directly inherits a bytes.Buffer, +// so you can also use all these functions too. +type Info struct { + bytes.Buffer + Style string +} + +// InfoFlag defines possible options for adding sections to a Info. +type InfoFlag int + +const ( + // NoFlags does nothing. + NoFlags InfoFlag = 0 + + // UseCodeSection wraps the section content in a markdown code section. + UseCodeSection InfoFlag = 1 + + // AddContentLineBreaks adds a line breaks after each line of content, + // except for the last. + AddContentLineBreaks InfoFlag = 2 +) + +func useCodeSection(flags InfoFlag) bool { + return flags&UseCodeSection > 0 +} + +func addContentLineBreaks(flags InfoFlag) bool { + return flags&AddContentLineBreaks > 0 +} + +// AddSection adds a debug section to the Info. The result is directly +// written into the buffer. 
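A possible use of CallLimiter from the caller's side: concurrent callers are bundled into a single execution, and executions are kept at least the configured pause apart. The one-second pause and the simulated work below are arbitrary example values:

package main

import (
	"sync"
	"time"

	"github.com/safing/portmaster/base/utils"
)

// refreshLimiter bundles concurrent refreshes and enforces a minimum
// pause between executions; one second is an example value.
var refreshLimiter = utils.NewCallLimiter(time.Second)

func refreshCache() {
	refreshLimiter.Do(func() {
		// Expensive work goes here. Every goroutine blocked in Do
		// returns once this single execution has finished.
		time.Sleep(50 * time.Millisecond)
	})
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			refreshCache()
		}()
	}
	wg.Wait()
}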
+func (di *Info) AddSection(name string, flags InfoFlag, content ...string) { + // Check if we need a spacer. + if di.Len() > 0 { + _, _ = di.WriteString("\n\n") + } + + // Write section to buffer. + + // Write section header. + if di.Style == "github" { + _, _ = di.WriteString(fmt.Sprintf("
\n%s\n\n", name)) + } else { + _, _ = di.WriteString(fmt.Sprintf("**%s**:\n\n", name)) + } + + // Write section content. + if useCodeSection(flags) { + // Write code header: Needs one empty line between previous data. + _, _ = di.WriteString("```\n") + } + for i, part := range content { + _, _ = di.WriteString(part) + if addContentLineBreaks(flags) && i < len(content)-1 { + _, _ = di.WriteString("\n") + } + } + if useCodeSection(flags) { + // Write code footer: Needs one empty line between next data. + _, _ = di.WriteString("\n```\n") + } + + // Write section header. + if di.Style == "github" { + _, _ = di.WriteString("\n
") + } +} + +// AddVersionInfo adds version information from the info pkg. +func (di *Info) AddVersionInfo() { + di.AddSection( + "Version "+info.Version(), + UseCodeSection, + info.FullVersion(), + ) +} + +// AddGoroutineStack adds the current goroutine stack. +func (di *Info) AddGoroutineStack() { + buf := new(bytes.Buffer) + err := pprof.Lookup("goroutine").WriteTo(buf, 1) + if err != nil { + di.AddSection( + "Goroutine Stack", + NoFlags, + fmt.Sprintf("Failed to get: %s", err), + ) + return + } + + // Add section. + di.AddSection( + "Goroutine Stack", + UseCodeSection, + buf.String(), + ) +} + +// AddLastReportedModuleError adds the last reported module error, if one exists. +func (di *Info) AddLastReportedModuleError() { + me := modules.GetLastReportedError() + if me == nil { + di.AddSection("No Module Error", NoFlags) + return + } + + di.AddSection( + fmt.Sprintf("%s Module Error", strings.Title(me.ModuleName)), //nolint:staticcheck + UseCodeSection, + me.Format(), + ) +} + +// AddLastUnexpectedLogs adds the last 10 unexpected log lines, if any. +func (di *Info) AddLastUnexpectedLogs() { + lines := log.GetLastUnexpectedLogs() + + // Check if there is anything at all. + if len(lines) == 0 { + di.AddSection("No Unexpected Logs", NoFlags) + return + } + + di.AddSection( + "Unexpected Logs", + UseCodeSection|AddContentLineBreaks, + append( + lines, + fmt.Sprintf("%s CURRENT TIME", time.Now().Format("060102 15:04:05.000")), + )..., + ) +} diff --git a/base/utils/debug/debug_android.go b/base/utils/debug/debug_android.go new file mode 100644 index 000000000..265702e7f --- /dev/null +++ b/base/utils/debug/debug_android.go @@ -0,0 +1,31 @@ +package debug + +import ( + "context" + "fmt" + + "github.com/safing/portmaster-android/go/app_interface" +) + +// AddPlatformInfo adds OS and platform information. +func (di *Info) AddPlatformInfo(_ context.Context) { + // Get information from the system. + info, err := app_interface.GetPlatformInfo() + if err != nil { + di.AddSection( + "Platform Information", + NoFlags, + fmt.Sprintf("Failed to get: %s", err), + ) + return + } + + // Add section. + di.AddSection( + fmt.Sprintf("Platform: Android"), + UseCodeSection|AddContentLineBreaks, + fmt.Sprintf("SDK: %d", info.SDK), + fmt.Sprintf("Device: %s %s (%s)", info.Manufacturer, info.Brand, info.Board), + fmt.Sprintf("App: %s: %s %s", info.ApplicationID, info.VersionName, info.BuildType)) + +} diff --git a/base/utils/debug/debug_default.go b/base/utils/debug/debug_default.go new file mode 100644 index 000000000..cb429b816 --- /dev/null +++ b/base/utils/debug/debug_default.go @@ -0,0 +1,43 @@ +//go:build !android + +package debug + +import ( + "context" + "fmt" + + "github.com/shirou/gopsutil/host" +) + +// AddPlatformInfo adds OS and platform information. +func (di *Info) AddPlatformInfo(ctx context.Context) { + // Get information from the system. + info, err := host.InfoWithContext(ctx) + if err != nil { + di.AddSection( + "Platform Information", + NoFlags, + fmt.Sprintf("Failed to get: %s", err), + ) + return + } + + // Check if we want to add virtulization information. + var virtInfo string + if info.VirtualizationRole == "guest" { + if info.VirtualizationSystem != "" { + virtInfo = fmt.Sprintf("VM: %s", info.VirtualizationSystem) + } else { + virtInfo = "VM: unidentified" + } + } + + // Add section. 
+ di.AddSection( + fmt.Sprintf("Platform: %s %s", info.Platform, info.PlatformVersion), + UseCodeSection|AddContentLineBreaks, + fmt.Sprintf("System: %s %s (%s) %s", info.Platform, info.OS, info.PlatformFamily, info.PlatformVersion), + fmt.Sprintf("Kernel: %s %s", info.KernelVersion, info.KernelArch), + virtInfo, + ) +} diff --git a/base/utils/fs.go b/base/utils/fs.go new file mode 100644 index 000000000..b612a0699 --- /dev/null +++ b/base/utils/fs.go @@ -0,0 +1,51 @@ +package utils + +import ( + "errors" + "fmt" + "io/fs" + "os" + "runtime" +) + +const isWindows = runtime.GOOS == "windows" + +// EnsureDirectory ensures that the given directory exists and that is has the given permissions set. +// If path is a file, it is deleted and a directory created. +func EnsureDirectory(path string, perm os.FileMode) error { + // open path + f, err := os.Stat(path) + if err == nil { + // file exists + if f.IsDir() { + // directory exists, check permissions + if isWindows { + // TODO: set correct permission on windows + // acl.Chmod(path, perm) + } else if f.Mode().Perm() != perm { + return os.Chmod(path, perm) + } + return nil + } + err = os.Remove(path) + if err != nil { + return fmt.Errorf("could not remove file %s to place dir: %w", path, err) + } + } + // file does not exist (or has been deleted) + if err == nil || errors.Is(err, fs.ErrNotExist) { + err = os.Mkdir(path, perm) + if err != nil { + return fmt.Errorf("could not create dir %s: %w", path, err) + } + return os.Chmod(path, perm) + } + // other error opening path + return fmt.Errorf("failed to access %s: %w", path, err) +} + +// PathExists returns whether the given path (file or dir) exists. +func PathExists(path string) bool { + _, err := os.Stat(path) + return err == nil || errors.Is(err, fs.ErrExist) +} diff --git a/base/utils/mimetypes.go b/base/utils/mimetypes.go new file mode 100644 index 000000000..dbf55f181 --- /dev/null +++ b/base/utils/mimetypes.go @@ -0,0 +1,78 @@ +package utils + +import "strings" + +// Do not depend on the OS for mimetypes. +// A Windows update screwed us over here and broke all the automatic mime +// typing via Go in April 2021. + +// MimeTypeByExtension returns a mimetype for the given file name extension, +// which must including the leading dot. +// If the extension is not known, the call returns with ok=false and, +// additionally, a default "application/octet-stream" mime type is returned. 
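The directory helpers above can be exercised as follows; the path and permission bits are examples, and note that permission correction on Windows is still a TODO in EnsureDirectory:

package main

import (
	"log"

	"github.com/safing/portmaster/base/utils"
)

func main() {
	// Example path and mode: a regular file at this path would be
	// replaced by a directory, and mismatching permissions corrected.
	if err := utils.EnsureDirectory("/tmp/portmaster-example", 0o700); err != nil {
		log.Fatalf("failed to prepare directory: %s", err)
	}

	if utils.PathExists("/tmp/portmaster-example") {
		log.Println("directory is ready")
	}
}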
+func MimeTypeByExtension(ext string) (mimeType string, ok bool) { + mimeType, ok = mimeTypes[strings.ToLower(ext)] + if ok { + return + } + + return defaultMimeType, false +} + +var ( + defaultMimeType = "application/octet-stream" + + mimeTypes = map[string]string{ + ".7z": "application/x-7z-compressed", + ".atom": "application/atom+xml", + ".css": "text/css; charset=utf-8", + ".csv": "text/csv; charset=utf-8", + ".deb": "application/x-debian-package", + ".epub": "application/epub+zip", + ".es": "application/ecmascript", + ".flv": "video/x-flv", + ".gif": "image/gif", + ".gz": "application/gzip", + ".htm": "text/html; charset=utf-8", + ".html": "text/html; charset=utf-8", + ".jpeg": "image/jpeg", + ".jpg": "image/jpeg", + ".js": "text/javascript; charset=utf-8", + ".json": "application/json; charset=utf-8", + ".m3u": "audio/mpegurl", + ".m4a": "audio/mpeg", + ".md": "text/markdown; charset=utf-8", + ".mjs": "text/javascript; charset=utf-8", + ".mov": "video/quicktime", + ".mp3": "audio/mpeg", + ".mp4": "video/mp4", + ".mpeg": "video/mpeg", + ".mpg": "video/mpeg", + ".ogg": "audio/ogg", + ".ogv": "video/ogg", + ".otf": "font/otf", + ".pdf": "application/pdf", + ".png": "image/png", + ".qt": "video/quicktime", + ".rar": "application/rar", + ".rtf": "application/rtf", + ".svg": "image/svg+xml", + ".tar": "application/x-tar", + ".tiff": "image/tiff", + ".ts": "video/MP2T", + ".ttc": "font/collection", + ".ttf": "font/ttf", + ".txt": "text/plain; charset=utf-8", + ".wasm": "application/wasm", + ".wav": "audio/x-wav", + ".webm": "video/webm", + ".webp": "image/webp", + ".woff": "font/woff", + ".woff2": "font/woff2", + ".xml": "text/xml; charset=utf-8", + ".xz": "application/x-xz", + ".yaml": "application/yaml; charset=utf-8", + ".yml": "application/yaml; charset=utf-8", + ".zip": "application/zip", + } +) diff --git a/base/utils/onceagain.go b/base/utils/onceagain.go new file mode 100644 index 000000000..3c4af5c62 --- /dev/null +++ b/base/utils/onceagain.go @@ -0,0 +1,86 @@ +package utils + +// This file is forked from https://github.com/golang/go/blob/bc593eac2dc63d979a575eccb16c7369a5ff81e0/src/sync/once.go. + +import ( + "sync" + "sync/atomic" +) + +// OnceAgain is an object that will perform only one action "in flight". It's +// basically the same as sync.Once, but is automatically reused when the +// function was executed and everyone who waited has left. +// Important: This is somewhat racy when used heavily as it only resets _after_ +// everyone who waited has left. So, while some goroutines are waiting to be +// activated again to leave the waiting state, other goroutines will call Do() +// without executing the function again. +type OnceAgain struct { + // done indicates whether the action has been performed. + // It is first in the struct because it is used in the hot path. + // The hot path is inlined at every call site. + // Placing done first allows more compact instructions on some architectures (amd64/x86), + // and fewer instructions (to calculate offset) on other architectures. + done uint32 + + // Number of waiters waiting for the function to finish. The last waiter resets done. + waiters int32 + + m sync.Mutex +} + +// Do calls the function f if and only if Do is being called for the +// first time for this instance of Once. In other words, given +// +// var once Once +// +// if once.Do(f) is called multiple times, only the first call will invoke f, +// even if f has a different value in each invocation. A new instance of +// Once is required for each function to execute. 
+// +// Do is intended for initialization that must be run exactly once. Since f +// is niladic, it may be necessary to use a function literal to capture the +// arguments to a function to be invoked by Do: +// +// config.once.Do(func() { config.init(filename) }) +// +// Because no call to Do returns until the one call to f returns, if f causes +// Do to be called, it will deadlock. +// +// If f panics, Do considers it to have returned; future calls of Do return +// without calling f. +func (o *OnceAgain) Do(f func()) { + // Note: Here is an incorrect implementation of Do: + // + // if atomic.CompareAndSwapUint32(&o.done, 0, 1) { + // f() + // } + // + // Do guarantees that when it returns, f has finished. + // This implementation would not implement that guarantee: + // given two simultaneous calls, the winner of the cas would + // call f, and the second would return immediately, without + // waiting for the first's call to f to complete. + // This is why the slow path falls back to a mutex, and why + // the atomic.StoreUint32 must be delayed until after f returns. + + if atomic.LoadUint32(&o.done) == 0 { + // Outlined slow-path to allow inlining of the fast-path. + o.doSlow(f) + } +} + +func (o *OnceAgain) doSlow(f func()) { + atomic.AddInt32(&o.waiters, 1) + defer func() { + if atomic.AddInt32(&o.waiters, -1) == 0 { + atomic.StoreUint32(&o.done, 0) // reset + } + }() + + o.m.Lock() + defer o.m.Unlock() + if o.done == 0 { + defer atomic.StoreUint32(&o.done, 1) + f() + } +} diff --git a/base/utils/onceagain_test.go b/base/utils/onceagain_test.go new file mode 100644 index 000000000..5d4e4aaff --- /dev/null +++ b/base/utils/onceagain_test.go @@ -0,0 +1,60 @@ +package utils + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/tevino/abool" +) + +func TestOnceAgain(t *testing.T) { + t.Parallel() + + oa := OnceAgain{} + executed := abool.New() + var testWg sync.WaitGroup + + // One execution should gobble up the whole batch. + for i := 0; i < 10; i++ { + testWg.Add(100) + for i := 0; i < 100; i++ { + go func() { + oa.Do(func() { + if !executed.SetToIf(false, true) { + t.Errorf("concurrent execution!") + } + time.Sleep(10 * time.Millisecond) + }) + testWg.Done() + }() + } + testWg.Wait() + executed.UnSet() // reset check + } + + // Continuous use with re-execution. + // Choose values so that about 10 executions are expected + var execs uint32 + testWg.Add(100) + for i := 0; i < 100; i++ { + go func() { + oa.Do(func() { + atomic.AddUint32(&execs, 1) + time.Sleep(10 * time.Millisecond) + }) + testWg.Done() + }() + + time.Sleep(1 * time.Millisecond) + } + + testWg.Wait() + if execs <= 8 { + t.Errorf("unexpected low exec count: %d", execs) + } + if execs >= 12 { + t.Errorf("unexpected high exec count: %d", execs) + } +} diff --git a/base/utils/osdetail/colors_windows.go b/base/utils/osdetail/colors_windows.go new file mode 100644 index 000000000..9a1ad7da0 --- /dev/null +++ b/base/utils/osdetail/colors_windows.go @@ -0,0 +1,51 @@ +package osdetail + +import ( + "sync" + + "golang.org/x/sys/windows" +) + +var ( + colorSupport bool + + colorSupportChecked bool + checkingColorSupport sync.Mutex +) + +// EnableColorSupport tries to enable color support for cmd on windows and returns whether it is enabled. 
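A sketch of how OnceAgain might be used by callers: while one execution is in flight, additional callers wait for that execution instead of starting their own, and once all waiters have left, the next call runs the function again. Names and timings below are illustrative:

package main

import (
	"fmt"
	"sync"
	"time"

	"github.com/safing/portmaster/base/utils"
)

var reloadOnce utils.OnceAgain

// reloadConfig bundles concurrent reload requests into one execution.
func reloadConfig() {
	reloadOnce.Do(func() {
		fmt.Println("reloading (example)")
		time.Sleep(10 * time.Millisecond)
	})
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reloadConfig()
		}()
	}
	wg.Wait()
}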
+func EnableColorSupport() bool { + checkingColorSupport.Lock() + defer checkingColorSupport.Unlock() + + if !colorSupportChecked { + colorSupport = enableColorSupport() + colorSupportChecked = true + } + return colorSupport +} + +func enableColorSupport() bool { + if IsAtLeastWindowsNTVersionWithDefault("10", false) { + + // check if windows.Stdout is file + if windows.GetFileInformationByHandle(windows.Stdout, &windows.ByHandleFileInformation{}) == nil { + return false + } + + var mode uint32 + err := windows.GetConsoleMode(windows.Stdout, &mode) + if err == nil { + if mode&windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING == 0 { + mode |= windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING + err = windows.SetConsoleMode(windows.Stdout, mode) + if err != nil { + return false + } + } + return true + } + } + + return false +} diff --git a/base/utils/osdetail/command.go b/base/utils/osdetail/command.go new file mode 100644 index 000000000..9285e3626 --- /dev/null +++ b/base/utils/osdetail/command.go @@ -0,0 +1,51 @@ +package osdetail + +import ( + "bytes" + "errors" + "os/exec" + "strings" +) + +// RunCmd runs the given command and run error checks on the output. +func RunCmd(command ...string) (output []byte, err error) { + // Create command to execute. + var cmd *exec.Cmd + switch len(command) { + case 0: + return nil, errors.New("no command supplied") + case 1: + cmd = exec.Command(command[0]) + default: + cmd = exec.Command(command[0], command[1:]...) + } + + // Create and assign output buffers. + var stdoutBuf bytes.Buffer + var stderrBuf bytes.Buffer + cmd.Stdout = &stdoutBuf + cmd.Stderr = &stderrBuf + + // Run command and collect output. + err = cmd.Run() + stdout, stderr := stdoutBuf.Bytes(), stderrBuf.Bytes() + if err != nil { + return nil, err + } + // Command might not return an error, but just write to stdout instead. + if len(stderr) > 0 { + return nil, errors.New(strings.SplitN(string(stderr), "\n", 2)[0]) + } + + // Debugging output: + // fmt.Printf("command stdout: %s\n", stdout) + // fmt.Printf("command stderr: %s\n", stderr) + + // Finalize stdout. + cleanedOutput := bytes.TrimSpace(stdout) + if len(cleanedOutput) == 0 { + return nil, ErrEmptyOutput + } + + return cleanedOutput, nil +} diff --git a/base/utils/osdetail/dnscache_windows.go b/base/utils/osdetail/dnscache_windows.go new file mode 100644 index 000000000..dfe364a2f --- /dev/null +++ b/base/utils/osdetail/dnscache_windows.go @@ -0,0 +1,17 @@ +package osdetail + +import ( + "os/exec" +) + +// EnableDNSCache enables the Windows Service "DNS Client" by setting the registry value "HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\Dnscache" to 2 (Automatic). +// A reboot is required for this setting to take effect. +func EnableDNSCache() error { + return exec.Command("reg", "add", "HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\services\\Dnscache", "/v", "Start", "/t", "REG_DWORD", "/d", "2", "/f").Run() +} + +// DisableDNSCache disables the Windows Service "DNS Client" by setting the registry value "HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\services\Dnscache" to 4 (Disabled). +// A reboot is required for this setting to take effect. 
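RunCmd above treats stderr output and empty stdout as errors, which keeps call sites short. A minimal package-internal sketch; the queried service name is only an example:

// queryDNSCacheSketch is illustrative only and not part of this patch.
// RunCmd trims the output, turns the first stderr line into an error,
// and returns ErrEmptyOutput when stdout was empty.
func queryDNSCacheSketch() ([]byte, error) {
	return RunCmd("sc", "query", "Dnscache")
}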
+func DisableDNSCache() error { + return exec.Command("reg", "add", "HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\services\\Dnscache", "/v", "Start", "/t", "REG_DWORD", "/d", "4", "/f").Run() +} diff --git a/base/utils/osdetail/errors.go b/base/utils/osdetail/errors.go new file mode 100644 index 000000000..b79ce11e5 --- /dev/null +++ b/base/utils/osdetail/errors.go @@ -0,0 +1,12 @@ +package osdetail + +import "errors" + +var ( + // ErrNotSupported is returned when an operation is not supported on the current platform. + ErrNotSupported = errors.New("not supported") + // ErrNotFound is returned when the desired data is not found. + ErrNotFound = errors.New("not found") + // ErrEmptyOutput is a special error that is returned when an operation has no error, but also returns to data. + ErrEmptyOutput = errors.New("command succeeded with empty output") +) diff --git a/base/utils/osdetail/service_windows.go b/base/utils/osdetail/service_windows.go new file mode 100644 index 000000000..5774db041 --- /dev/null +++ b/base/utils/osdetail/service_windows.go @@ -0,0 +1,112 @@ +package osdetail + +import ( + "errors" + "fmt" + "os/exec" + "strings" + "time" +) + +// Service Status +const ( + StatusUnknown uint8 = iota + StatusRunningStoppable + StatusRunningNotStoppable + StatusStartPending + StatusStopPending + StatusStopped +) + +// Exported errors +var ( + ErrServiceNotStoppable = errors.New("the service is not stoppable") +) + +// GetServiceStatus returns the current status of a Windows Service (limited implementation). +func GetServiceStatus(name string) (status uint8, err error) { + + output, err := exec.Command("sc", "query", name).Output() + if err != nil { + return StatusUnknown, fmt.Errorf("failed to query service: %s", err) + } + outputString := string(output) + + switch { + case strings.Contains(outputString, "RUNNING"): + if strings.Contains(outputString, "NOT_STOPPABLE") { + return StatusRunningNotStoppable, nil + } + return StatusRunningStoppable, nil + case strings.Contains(outputString, "STOP_PENDING"): + return StatusStopPending, nil + case strings.Contains(outputString, "STOPPED"): + return StatusStopped, nil + case strings.Contains(outputString, "START_PENDING"): + return StatusStopPending, nil + } + + return StatusUnknown, errors.New("unknown service status") +} + +// StopService stops a Windows Service. +func StopService(name string) (err error) { + pendingCnt := 0 + for { + + // get status + status, err := GetServiceStatus(name) + if err != nil { + return err + } + + switch status { + case StatusRunningStoppable: + err := exec.Command("sc", "stop", name).Run() + if err != nil { + return fmt.Errorf("failed to stop service: %s", err) + } + case StatusRunningNotStoppable: + return ErrServiceNotStoppable + case StatusStartPending, StatusStopPending: + pendingCnt++ + if pendingCnt > 50 { + return errors.New("service stuck in pending status (5s)") + } + case StatusStopped: + return nil + } + + time.Sleep(100 * time.Millisecond) + } +} + +// SartService starts a Windows Service. 
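Building on the service helpers above, a cautious caller might inspect the status first and bail out early when stopping is impossible. This wrapper is ours and only a sketch:

// stopServiceSketch is illustrative only and not part of this patch.
func stopServiceSketch(name string) error {
	status, err := GetServiceStatus(name)
	if err != nil {
		return err
	}

	switch status {
	case StatusStopped:
		return nil // Nothing to do.
	case StatusRunningNotStoppable:
		return ErrServiceNotStoppable
	default:
		// StopService polls every 100ms and gives up after roughly
		// five seconds of pending state.
		return StopService(name)
	}
}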
+func SartService(name string) (err error) { + pendingCnt := 0 + for { + + // get status + status, err := GetServiceStatus(name) + if err != nil { + return err + } + + switch status { + case StatusRunningStoppable, StatusRunningNotStoppable: + return nil + case StatusStartPending, StatusStopPending: + pendingCnt++ + if pendingCnt > 50 { + return errors.New("service stuck in pending status (5s)") + } + case StatusStopped: + err := exec.Command("sc", "start", name).Run() + if err != nil { + return fmt.Errorf("failed to stop service: %s", err) + } + } + + time.Sleep(100 * time.Millisecond) + } +} diff --git a/base/utils/osdetail/shell_windows.go b/base/utils/osdetail/shell_windows.go new file mode 100644 index 000000000..926b7c8ab --- /dev/null +++ b/base/utils/osdetail/shell_windows.go @@ -0,0 +1,49 @@ +package osdetail + +import ( + "bytes" + "errors" +) + +// RunPowershellCmd runs a powershell command and returns its output. +func RunPowershellCmd(script string) (output []byte, err error) { + // Create command to execute. + return RunCmd( + "powershell.exe", + "-ExecutionPolicy", "Bypass", + "-NoProfile", + "-NonInteractive", + "[System.Console]::OutputEncoding = [System.Text.Encoding]::UTF8\n"+script, + ) +} + +const outputSeparator = "pwzzhtuvpwdgozhzbnjj" + +// RunTerminalCmd runs a Windows cmd command and returns its output. +// It sets the output of the cmd to UTF-8 in order to avoid encoding errors. +func RunTerminalCmd(command ...string) (output []byte, err error) { + output, err = RunCmd(append([]string{ + "cmd.exe", + "/c", + "chcp", // Set output encoding... + "65001", // ...to UTF-8. + "&", + "echo", + outputSeparator, + "&", + }, + command..., + )...) + if err != nil { + return nil, err + } + + // Find correct start of output and shift start. + index := bytes.IndexAny(output, outputSeparator+"\r\n") + if index < 0 { + return nil, errors.New("failed to post-process output: could not find output separator") + } + output = output[index+len(outputSeparator)+2:] + + return output, nil +} diff --git a/base/utils/osdetail/svchost_windows.go b/base/utils/osdetail/svchost_windows.go new file mode 100644 index 000000000..3b7cd4f94 --- /dev/null +++ b/base/utils/osdetail/svchost_windows.go @@ -0,0 +1,120 @@ +package osdetail + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "os/exec" + "strconv" + "strings" + "sync" +) + +var ( + serviceNames map[int32][]string + serviceNamesLock sync.Mutex +) + +// Errors +var ( + ErrServiceNotFound = errors.New("no service with the given PID was found") +) + +// GetServiceNames returns all service names assosicated with a svchost.exe process on Windows. +func GetServiceNames(pid int32) ([]string, error) { + serviceNamesLock.Lock() + defer serviceNamesLock.Unlock() + + if serviceNames != nil { + names, ok := serviceNames[pid] + if ok { + return names, nil + } + } + + serviceNames, err := GetAllServiceNames() + if err != nil { + return nil, err + } + + names, ok := serviceNames[pid] + if ok { + return names, nil + } + + return nil, ErrServiceNotFound +} + +// GetAllServiceNames returns a list of service names assosicated with svchost.exe processes on Windows. 
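The shell wrappers above hide the invocation and encoding details of powershell.exe and cmd.exe. A small package-internal sketch of calling one of them; the PowerShell query itself is just an example:

// osBuildNumberSketch is illustrative only and not part of this patch.
func osBuildNumberSketch() (string, error) {
	out, err := RunPowershellCmd("[System.Environment]::OSVersion.Version.Build")
	if err != nil {
		return "", err
	}
	return string(out), nil
}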
+func GetAllServiceNames() (map[int32][]string, error) { + output, err := exec.Command("tasklist", "/svc", "/fi", "imagename eq svchost.exe").Output() + if err != nil { + return nil, fmt.Errorf("failed to get svchost tasklist: %s", err) + } + + // file scanner + scanner := bufio.NewScanner(bytes.NewReader(output)) + scanner.Split(bufio.ScanLines) + + // skip output header + for scanner.Scan() { + if strings.HasPrefix(scanner.Text(), "=") { + break + } + } + + var ( + pid int32 + services []string + collection = make(map[int32][]string) + ) + + for scanner.Scan() { + // get fields of line + fields := strings.Fields(scanner.Text()) + + // check fields length + if len(fields) == 0 { + continue + } + + // new entry + if fields[0] == "svchost.exe" { + // save old entry + if pid != 0 { + collection[pid] = services + } + // reset PID + pid = 0 + services = make([]string, 0, len(fields)) + + // check fields length + if len(fields) < 3 { + continue + } + + // get pid + i, err := strconv.ParseInt(fields[1], 10, 32) + if err != nil { + continue + } + pid = int32(i) + + // skip used fields + fields = fields[2:] + } + + // add service names + for _, field := range fields { + services = append(services, strings.Trim(strings.TrimSpace(field), ",")) + } + } + + if pid != 0 { + // save last entry + collection[pid] = services + } + + return collection, nil +} diff --git a/base/utils/osdetail/version_windows.go b/base/utils/osdetail/version_windows.go new file mode 100644 index 000000000..c8db11c6a --- /dev/null +++ b/base/utils/osdetail/version_windows.go @@ -0,0 +1,99 @@ +package osdetail + +import ( + "fmt" + "strings" + "sync" + + "github.com/hashicorp/go-version" + "github.com/shirou/gopsutil/host" +) + +var ( + // versionRe = regexp.MustCompile(`[0-9\.]+`) + + windowsNTVersion string + windowsNTVersionForCmp *version.Version + + fetching sync.Mutex + fetched bool +) + +// WindowsNTVersion returns the current Windows version. +func WindowsNTVersion() (string, error) { + var err error + fetching.Lock() + defer fetching.Unlock() + + if !fetched { + _, _, windowsNTVersion, err = host.PlatformInformation() + + windowsNTVersion = strings.SplitN(windowsNTVersion, " ", 2)[0] + + if err != nil { + return "", fmt.Errorf("failed to obtain Windows-Version: %s", err) + } + + windowsNTVersionForCmp, err = version.NewVersion(windowsNTVersion) + + if err != nil { + return "", fmt.Errorf("failed to parse Windows-Version %s: %s", windowsNTVersion, err) + } + + fetched = true + } + + return windowsNTVersion, err +} + +// IsAtLeastWindowsNTVersion returns whether the current WindowsNT version is at least the given version or newer. +func IsAtLeastWindowsNTVersion(v string) (bool, error) { + _, err := WindowsNTVersion() + if err != nil { + return false, err + } + + versionForCmp, err := version.NewVersion(v) + if err != nil { + return false, err + } + + return windowsNTVersionForCmp.GreaterThanOrEqual(versionForCmp), nil +} + +// IsAtLeastWindowsNTVersionWithDefault is like IsAtLeastWindowsNTVersion(), but keeps the Error and returns the default Value in Errorcase +func IsAtLeastWindowsNTVersionWithDefault(v string, defaultValue bool) bool { + val, err := IsAtLeastWindowsNTVersion(v) + if err != nil { + return defaultValue + } + return val +} + +// IsAtLeastWindowsVersion returns whether the current Windows version is at least the given version or newer. 
+func IsAtLeastWindowsVersion(v string) (bool, error) { + var NTVersion string + switch v { + case "7": + NTVersion = "6.1" + case "8": + NTVersion = "6.2" + case "8.1": + NTVersion = "6.3" + case "10": + NTVersion = "10" + default: + return false, fmt.Errorf("failed to compare Windows-Version: Windows %s is unknown", v) + } + + return IsAtLeastWindowsNTVersion(NTVersion) +} + +// IsAtLeastWindowsVersionWithDefault is like IsAtLeastWindowsVersion(), but keeps the Error and returns the default Value in Errorcase +func IsAtLeastWindowsVersionWithDefault(v string, defaultValue bool) bool { + val, err := IsAtLeastWindowsVersion(v) + if err != nil { + return defaultValue + } + return val +} diff --git a/base/utils/osdetail/version_windows_test.go b/base/utils/osdetail/version_windows_test.go new file mode 100644 index 000000000..c9cb409d3 --- /dev/null +++ b/base/utils/osdetail/version_windows_test.go @@ -0,0 +1,29 @@ +package osdetail + +import "testing" + +func TestWindowsNTVersion(t *testing.T) { + if str, err := WindowsNTVersion(); str == "" || err != nil { + t.Fatalf("failed to obtain windows version: %s", err) + } +} + +func TestIsAtLeastWindowsNTVersion(t *testing.T) { + ret, err := IsAtLeastWindowsNTVersion("6") + if err != nil { + t.Fatalf("failed to compare windows versions: %s", err) + } + if !ret { + t.Fatalf("WindowsNTVersion is less than 6 (Vista)") + } +} + +func TestIsAtLeastWindowsVersion(t *testing.T) { + ret, err := IsAtLeastWindowsVersion("7") + if err != nil { + t.Fatalf("failed to compare windows versions: %s", err) + } + if !ret { + t.Fatalf("WindowsVersion is less than 7") + } +} diff --git a/base/utils/renameio/LICENSE b/base/utils/renameio/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/base/utils/renameio/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/base/utils/renameio/README.md b/base/utils/renameio/README.md new file mode 100644 index 000000000..91e0b293e --- /dev/null +++ b/base/utils/renameio/README.md @@ -0,0 +1,55 @@ +This is a fork of the github.com/google/renameio Go package at commit 353f8196982447d8b12c64f69530e657331e3dbc. + +The inital commit of this package will carry the original package contents. +The Original License is the Apache License in Version 2.0 and the copyright of the forked package is held by Google Inc. +Any changes are recorded in the git history, which is part of this project. + +--- + +The `renameio` Go package provides a way to atomically create or replace a file or +symbolic link. + +## Atomicity vs durability + +`renameio` concerns itself *only* with atomicity, i.e. making sure applications +never see unexpected file content (a half-written file, or a 0-byte file). 
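To make that guarantee concrete, a minimal sketch of the intended usage via the package's `WriteFile` helper (the path, payload, and permissions are illustrative; the import path reflects the fork's new location in this patch):

```go
package main

import (
	"log"

	"github.com/safing/portmaster/base/utils/renameio"
)

func main() {
	// Readers of state.json see either the previous content or the new
	// content in full - never a truncated or empty file.
	data := []byte(`{"enabled": true}`)
	if err := renameio.WriteFile("/var/lib/example/state.json", data, 0o600); err != nil {
		log.Fatal(err)
	}
}
```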
+ +As a practical example, consider https://manpages.debian.org/: if there is a +power outage while the site is updating, we are okay with losing the manpages +which were being rendered at the time of the power outage. They will be added in +a later run of the software. We are not okay with having a manpage replaced by a +0-byte file under any circumstances, though. + +## Advantages of this package + +There are other packages for atomically replacing files, and sometimes ad-hoc +implementations can be found in programs. + +A naive approach to the problem is to create a temporary file followed by a call +to `os.Rename()`. However, there are a number of subtleties which make the +correct sequence of operations hard to identify: + +* The temporary file should be removed when an error occurs, but a remove must + not be attempted if the rename succeeded, as a new file might have been + created with the same name. This renders a throwaway `defer + os.Remove(t.Name())` insufficient; state must be kept. + +* The temporary file must be created on the same file system (same mount point) + for the rename to work, but the TMPDIR environment variable should still be + respected, e.g. to direct temporary files into a separate directory outside of + the webserver’s document root but on the same file system. + +* On POSIX operating systems, the + [`fsync`](https://manpages.debian.org/stretch/manpages-dev/fsync.2) system + call must be used to ensure that the `os.Rename()` call will not result in a + 0-length file. + +This package attempts to get all of these details right, provides an intuitive, +yet flexible API and caters to use-cases where high performance is required. + +## Disclaimer + +This is not an official Google product (experimental or otherwise), it +is just code that happens to be owned by Google. + +This project is not affiliated with the Go project. diff --git a/base/utils/renameio/doc.go b/base/utils/renameio/doc.go new file mode 100644 index 000000000..cb4c16dba --- /dev/null +++ b/base/utils/renameio/doc.go @@ -0,0 +1,7 @@ +// Package renameio provides a way to atomically create or replace a file or +// symbolic link. +// +// Caveat: this package requires the file system rename(2) implementation to be +// atomic. Notably, this is not the case when using NFS with multiple clients: +// https://stackoverflow.com/a/41396801 +package renameio diff --git a/base/utils/renameio/example_test.go b/base/utils/renameio/example_test.go new file mode 100644 index 000000000..e9ab871c1 --- /dev/null +++ b/base/utils/renameio/example_test.go @@ -0,0 +1,57 @@ +package renameio_test + +import ( + "fmt" + "log" + + "github.com/safing/portmaster/base/utils/renameio" +) + +func ExampleTempFile_justone() { //nolint:testableexamples + persist := func(temperature float64) error { + t, err := renameio.TempFile("", "/srv/www/metrics.txt") + if err != nil { + return err + } + defer func() { + _ = t.Cleanup() + }() + if _, err := fmt.Fprintf(t, "temperature_degc %f\n", temperature); err != nil { + return err + } + return t.CloseAtomicallyReplace() + } + // Thanks to the write package, a webserver exposing /srv/www never + // serves an incomplete or missing file. + if err := persist(31.2); err != nil { + log.Fatal(err) + } +} + +func ExampleTempFile_many() { //nolint:testableexamples + // Prepare for writing files to /srv/www, effectively caching calls to + // TempDir which TempFile would otherwise need to make. 
+ dir := renameio.TempDir("/srv/www") + persist := func(temperature float64) error { + t, err := renameio.TempFile(dir, "/srv/www/metrics.txt") + if err != nil { + return err + } + defer func() { + _ = t.Cleanup() + }() + if _, err := fmt.Fprintf(t, "temperature_degc %f\n", temperature); err != nil { + return err + } + return t.CloseAtomicallyReplace() + } + + // Imagine this was an endless loop, reading temperature sensor values. + // Thanks to the write package, a webserver exposing /srv/www never + // serves an incomplete or missing file. + for { + if err := persist(31.2); err != nil { + log.Fatal(err) + } + } +} diff --git a/base/utils/renameio/symlink_test.go b/base/utils/renameio/symlink_test.go new file mode 100644 index 000000000..a3a1b48db --- /dev/null +++ b/base/utils/renameio/symlink_test.go @@ -0,0 +1,41 @@ +//go:build darwin || dragonfly || freebsd || linux || nacl || netbsd || openbsd || solaris || windows + +package renameio + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestSymlink(t *testing.T) { + t.Parallel() + + d, err := os.MkdirTemp("", "test-renameio-testsymlink") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _ = os.RemoveAll(d) + }) + + want := []byte("Hello World") + if err := os.WriteFile(filepath.Join(d, "hello.txt"), want, 0o0600); err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + if err := Symlink("hello.txt", filepath.Join(d, "hi.txt")); err != nil { + t.Fatal(err) + } + + got, err := os.ReadFile(filepath.Join(d, "hi.txt")) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Fatalf("unexpected content: got %q, want %q", string(got), string(want)) + } + } +} diff --git a/base/utils/renameio/tempfile.go b/base/utils/renameio/tempfile.go new file mode 100644 index 000000000..9397789d5 --- /dev/null +++ b/base/utils/renameio/tempfile.go @@ -0,0 +1,170 @@ +package renameio + +import ( + "errors" + "io/fs" + "os" + "path/filepath" +) + +// TempDir checks whether os.TempDir() can be used as a temporary directory for +// later atomically replacing files within dest. If no (os.TempDir() resides on +// a different mount point), dest is returned. +// +// Note that the returned value ceases to be valid once either os.TempDir() +// changes (e.g. on Linux, once the TMPDIR environment variable changes) or the +// file system is unmounted. +func TempDir(dest string) string { + return tempDir("", filepath.Join(dest, "renameio-TempDir")) +} + +func tempDir(dir, dest string) string { + if dir != "" { + return dir // caller-specified directory always wins + } + + // Chose the destination directory as temporary directory so that we + // definitely can rename the file, for which both temporary and destination + // file need to point to the same mount point. + fallback := filepath.Dir(dest) + + // The user might have overridden the os.TempDir() return value by setting + // the TMPDIR environment variable. 
+ tmpdir := os.TempDir() + + testsrc, err := os.CreateTemp(tmpdir, "."+filepath.Base(dest)) + if err != nil { + return fallback + } + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(testsrc.Name()) + } + }() + _ = testsrc.Close() + + testdest, err := os.CreateTemp(filepath.Dir(dest), "."+filepath.Base(dest)) + if err != nil { + return fallback + } + defer func() { + _ = os.Remove(testdest.Name()) + }() + _ = testdest.Close() + + if err := os.Rename(testsrc.Name(), testdest.Name()); err != nil { + return fallback + } + cleanup = false // testsrc no longer exists + return tmpdir +} + +// PendingFile is a pending temporary file, waiting to replace the destination +// path in a call to CloseAtomicallyReplace. +type PendingFile struct { + *os.File + + path string + done bool + closed bool +} + +// Cleanup is a no-op if CloseAtomicallyReplace succeeded, and otherwise closes +// and removes the temporary file. +func (t *PendingFile) Cleanup() error { + if t.done { + return nil + } + // An error occurred. Close and remove the tempfile. Errors are returned for + // reporting, there is nothing the caller can recover here. + var closeErr error + if !t.closed { + closeErr = t.Close() + } + if err := os.Remove(t.Name()); err != nil { + return err + } + return closeErr +} + +// CloseAtomicallyReplace closes the temporary file and atomically replaces +// the destination file with it, i.e., a concurrent open(2) call will either +// open the file previously located at the destination path (if any), or the +// just written file, but the file will always be present. +func (t *PendingFile) CloseAtomicallyReplace() error { + // Even on an ordered file system (e.g. ext4 with data=ordered) or file + // systems with write barriers, we cannot skip the fsync(2) call as per + // Theodore Ts'o (ext2/3/4 lead developer): + // + // > data=ordered only guarantees the avoidance of stale data (e.g., the previous + // > contents of a data block showing up after a crash, where the previous data + // > could be someone's love letters, medical records, etc.). Without the fsync(2) + // > a zero-length file is a valid and possible outcome after the rename. + if err := t.Sync(); err != nil { + return err + } + t.closed = true + if err := t.Close(); err != nil { + return err + } + if err := os.Rename(t.Name(), t.path); err != nil { + return err + } + t.done = true + return nil +} + +// TempFile wraps os.CreateTemp for the use case of atomically creating or +// replacing the destination file at path. +// +// If dir is the empty string, TempDir(filepath.Base(path)) is used. If you are +// going to write a large number of files to the same file system, store the +// result of TempDir(filepath.Base(path)) and pass it instead of the empty +// string. +// +// The file's permissions will be 0600 by default. You can change these by +// explicitly calling Chmod on the returned PendingFile. +func TempFile(dir, path string) (*PendingFile, error) { + f, err := os.CreateTemp(tempDir(dir, path), "."+filepath.Base(path)) + if err != nil { + return nil, err + } + + return &PendingFile{File: f, path: path}, nil +} + +// Symlink wraps os.Symlink, replacing an existing symlink with the same name +// atomically (os.Symlink fails when newname already exists, at least on Linux). +func Symlink(oldname, newname string) error { + // Fast path: if newname does not exist yet, we can skip the whole dance + // below. 
+ if err := os.Symlink(oldname, newname); err == nil || !errors.Is(err, fs.ErrExist) { + return err + } + + // We need to use os.MkdirTemp, as we cannot overwrite a os.CreateTemp, + // and removing+symlinking creates a TOCTOU race. + d, err := os.MkdirTemp(filepath.Dir(newname), "."+filepath.Base(newname)) + if err != nil { + return err + } + cleanup := true + defer func() { + if cleanup { + _ = os.RemoveAll(d) + } + }() + + symlink := filepath.Join(d, "tmp.symlink") + if err := os.Symlink(oldname, symlink); err != nil { + return err + } + + if err := os.Rename(symlink, newname); err != nil { + return err + } + + cleanup = false + return os.RemoveAll(d) +} diff --git a/base/utils/renameio/tempfile_linux_test.go b/base/utils/renameio/tempfile_linux_test.go new file mode 100644 index 000000000..88ce025e0 --- /dev/null +++ b/base/utils/renameio/tempfile_linux_test.go @@ -0,0 +1,115 @@ +//go:build linux + +package renameio + +import ( + "os" + "path/filepath" + "syscall" + "testing" +) + +func TestTempDir(t *testing.T) { + t.Parallel() + + if tmpdir, ok := os.LookupEnv("TMPDIR"); ok { + t.Cleanup(func() { + _ = os.Setenv("TMPDIR", tmpdir) // restore + }) + } else { + t.Cleanup(func() { + _ = os.Unsetenv("TMPDIR") // restore + }) + } + + mount1, err := os.MkdirTemp("", "test-renameio-testtempdir1") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _ = os.RemoveAll(mount1) + }) + + mount2, err := os.MkdirTemp("", "test-renameio-testtempdir2") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _ = os.RemoveAll(mount2) + }) + + if err := syscall.Mount("tmpfs", mount1, "tmpfs", 0, ""); err != nil { + t.Skipf("cannot mount tmpfs on %s: %v", mount1, err) + } + t.Cleanup(func() { + _ = syscall.Unmount(mount1, 0) + }) + + if err := syscall.Mount("tmpfs", mount2, "tmpfs", 0, ""); err != nil { + t.Skipf("cannot mount tmpfs on %s: %v", mount2, err) + } + t.Cleanup(func() { + _ = syscall.Unmount(mount2, 0) + }) + + tests := []struct { + name string + dir string + path string + TMPDIR string + want string + }{ + { + name: "implicit TMPDIR", + path: filepath.Join(os.TempDir(), "foo.txt"), + want: os.TempDir(), + }, + + { + name: "explicit TMPDIR", + path: filepath.Join(mount1, "foo.txt"), + TMPDIR: mount1, + want: mount1, + }, + + { + name: "explicit unsuitable TMPDIR", + path: filepath.Join(mount1, "foo.txt"), + TMPDIR: mount2, + want: mount1, + }, + + { + name: "nonexistant TMPDIR", + path: filepath.Join(mount1, "foo.txt"), + TMPDIR: "/nonexistant", + want: mount1, + }, + + { + name: "caller-specified", + dir: "/overridden", + path: filepath.Join(mount1, "foo.txt"), + TMPDIR: "/nonexistant", + want: "/overridden", + }, + } + + for _, tt := range tests { + testCase := tt + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + if testCase.TMPDIR == "" { + _ = os.Unsetenv("TMPDIR") + } else { + _ = os.Setenv("TMPDIR", testCase.TMPDIR) + } + + if got := tempDir(testCase.dir, testCase.path); got != testCase.want { + t.Fatalf("tempDir(%q, %q): got %q, want %q", testCase.dir, testCase.path, got, testCase.want) + } + }) + } +} diff --git a/base/utils/renameio/writefile.go b/base/utils/renameio/writefile.go new file mode 100644 index 000000000..211530255 --- /dev/null +++ b/base/utils/renameio/writefile.go @@ -0,0 +1,26 @@ +package renameio + +import "os" + +// WriteFile mirrors os.WriteFile, replacing an existing file with the same +// name atomically. 
+func WriteFile(filename string, data []byte, perm os.FileMode) error { + t, err := TempFile("", filename) + if err != nil { + return err + } + defer func() { + _ = t.Cleanup() + }() + + // Set permissions before writing data, in case the data is sensitive. + if err := t.Chmod(perm); err != nil { + return err + } + + if _, err := t.Write(data); err != nil { + return err + } + + return t.CloseAtomicallyReplace() +} diff --git a/base/utils/renameio/writefile_test.go b/base/utils/renameio/writefile_test.go new file mode 100644 index 000000000..eaf302b56 --- /dev/null +++ b/base/utils/renameio/writefile_test.go @@ -0,0 +1,46 @@ +//go:build darwin || dragonfly || freebsd || linux || nacl || netbsd || openbsd || solaris || windows + +package renameio + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestWriteFile(t *testing.T) { + t.Parallel() + + d, err := os.MkdirTemp("", "test-renameio-testwritefile") + if err != nil { + t.Fatal(err) + } + defer func() { + _ = os.RemoveAll(d) + }() + + filename := filepath.Join(d, "hello.sh") + + wantData := []byte("#!/bin/sh\necho \"Hello World\"\n") + wantPerm := os.FileMode(0o0600) + if err := WriteFile(filename, wantData, wantPerm); err != nil { + t.Fatal(err) + } + + gotData, err := os.ReadFile(filename) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(gotData, wantData) { + t.Errorf("got data %v, want data %v", gotData, wantData) + } + + fi, err := os.Stat(filename) + if err != nil { + t.Fatal(err) + } + if gotPerm := fi.Mode() & os.ModePerm; gotPerm != wantPerm { + t.Errorf("got permissions 0%o, want permissions 0%o", gotPerm, wantPerm) + } +} diff --git a/base/utils/safe.go b/base/utils/safe.go new file mode 100644 index 000000000..90199cfad --- /dev/null +++ b/base/utils/safe.go @@ -0,0 +1,23 @@ +package utils + +import ( + "encoding/hex" + "strings" +) + +// SafeFirst16Bytes return the first 16 bytes of the given data in safe form. +func SafeFirst16Bytes(data []byte) string { + if len(data) == 0 { + return "" + } + + return strings.TrimPrefix( + strings.SplitN(hex.Dump(data), "\n", 2)[0], + "00000000 ", + ) +} + +// SafeFirst16Chars return the first 16 characters of the given data in safe form. +func SafeFirst16Chars(s string) string { + return SafeFirst16Bytes([]byte(s)) +} diff --git a/base/utils/safe_test.go b/base/utils/safe_test.go new file mode 100644 index 000000000..43480f77b --- /dev/null +++ b/base/utils/safe_test.go @@ -0,0 +1,29 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSafeFirst16(t *testing.T) { + t.Parallel() + + assert.Equal(t, + "47 6f 20 69 73 20 61 6e 20 6f 70 65 6e 20 73 6f |Go is an open so|", + SafeFirst16Bytes([]byte("Go is an open source programming language.")), + ) + assert.Equal(t, + "47 6f 20 69 73 20 61 6e 20 6f 70 65 6e 20 73 6f |Go is an open so|", + SafeFirst16Chars("Go is an open source programming language."), + ) + + assert.Equal(t, + "", + SafeFirst16Bytes(nil), + ) + assert.Equal(t, + "", + SafeFirst16Chars(""), + ) +} diff --git a/base/utils/slices.go b/base/utils/slices.go new file mode 100644 index 000000000..f185b3bce --- /dev/null +++ b/base/utils/slices.go @@ -0,0 +1,52 @@ +package utils + +// IndexOfString returns the index of given string and -1 if its not part of the slice. +func IndexOfString(a []string, s string) int { + for i, entry := range a { + if entry == s { + return i + } + } + return -1 +} + +// StringInSlice returns whether the given string is in the string slice. 
+func StringInSlice(a []string, s string) bool {
+	return IndexOfString(a, s) >= 0
+}
+
+// RemoveFromStringSlice removes the first occurrence of the given string from the slice and returns the resulting slice.
+func RemoveFromStringSlice(a []string, s string) []string {
+	i := IndexOfString(a, s)
+	if i >= 0 {
+		a = append(a[:i], a[i+1:]...)
+	}
+	return a
+}
+
+// DuplicateStrings returns a new copy of the given string slice.
+func DuplicateStrings(a []string) []string {
+	b := make([]string, len(a))
+	copy(b, a)
+	return b
+}
+
+// StringSliceEqual returns whether the given string slices are equal.
+func StringSliceEqual(a []string, b []string) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i, v := range a {
+		if v != b[i] {
+			return false
+		}
+	}
+	return true
+}
+
+// DuplicateBytes returns a new copy of the given byte slice.
+func DuplicateBytes(a []byte) []byte {
+	b := make([]byte, len(a))
+	copy(b, a)
+	return b
+}
diff --git a/base/utils/slices_test.go b/base/utils/slices_test.go
new file mode 100644
index 000000000..50ac49461
--- /dev/null
+++ b/base/utils/slices_test.go
@@ -0,0 +1,91 @@
+package utils
+
+import (
+	"bytes"
+	"testing"
+)
+
+var (
+	stringTestSlice  = []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}
+	stringTestSlice2 = []string{"a", "x", "x", "x", "x", "x", "x", "x", "x", "j"}
+	stringTestSlice3 = []string{"a", "x"}
+	byteTestSlice    = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+)
+
+func TestStringInSlice(t *testing.T) {
+	t.Parallel()
+
+	if !StringInSlice(stringTestSlice, "a") {
+		t.Fatal("string reported not in slice (1), but it is")
+	}
+	if !StringInSlice(stringTestSlice, "d") {
+		t.Fatal("string reported not in slice (2), but it is")
+	}
+	if !StringInSlice(stringTestSlice, "j") {
+		t.Fatal("string reported not in slice (3), but it is")
+	}
+
+	if StringInSlice(stringTestSlice, "0") {
+		t.Fatal("string reported in slice (1), but is not")
+	}
+	if StringInSlice(stringTestSlice, "x") {
+		t.Fatal("string reported in slice (2), but is not")
+	}
+	if StringInSlice(stringTestSlice, "k") {
+		t.Fatal("string reported in slice (3), but is not")
+	}
+}
+
+func TestRemoveFromStringSlice(t *testing.T) {
+	t.Parallel()
+
+	test1 := DuplicateStrings(stringTestSlice)
+	test1 = RemoveFromStringSlice(test1, "b")
+	if StringInSlice(test1, "b") {
+		t.Fatal("string reported in slice, but was removed")
+	}
+	if len(test1) != len(stringTestSlice)-1 {
+		t.Fatalf("new string slice length not as expected: is %d, should be %d\nnew slice is %v", len(test1), len(stringTestSlice)-1, test1)
+	}
+	RemoveFromStringSlice(test1, "b")
+}
+
+func TestDuplicateStrings(t *testing.T) {
+	t.Parallel()
+
+	a := DuplicateStrings(stringTestSlice)
+	if !StringSliceEqual(a, stringTestSlice) {
+		t.Fatal("copied string slice is not equal")
+	}
+	a[0] = "x"
+	if StringSliceEqual(a, stringTestSlice) {
+		t.Fatal("copied string slice is not a real copy")
+	}
+}
+
+func TestStringSliceEqual(t *testing.T) {
+	t.Parallel()
+
+	if !StringSliceEqual(stringTestSlice, stringTestSlice) {
+		t.Fatal("strings are equal, but are reported as not")
+	}
+	if StringSliceEqual(stringTestSlice, stringTestSlice2) {
+		t.Fatal("strings are not equal (1), but are reported as equal")
+	}
+	if StringSliceEqual(stringTestSlice, stringTestSlice3) {
+		t.Fatal("strings are not equal (2), but are reported as equal")
+	}
+}
+
+func TestDuplicateBytes(t *testing.T) {
+	t.Parallel()
+
+	a := DuplicateBytes(byteTestSlice)
+	if !bytes.Equal(a, byteTestSlice) {
+		t.Fatal("copied bytes slice is not equal")
+	}
+	a[0] = 0xff
+	if bytes.Equal(a, byteTestSlice) {
t.Fatal("copied bytes slice is not a real copy") + } +} diff --git a/base/utils/stablepool.go b/base/utils/stablepool.go new file mode 100644 index 000000000..147e96c6e --- /dev/null +++ b/base/utils/stablepool.go @@ -0,0 +1,118 @@ +package utils + +import "sync" + +// A StablePool is a drop-in replacement for sync.Pool that is slower, but +// predictable. +// A StablePool is a set of temporary objects that may be individually saved and +// retrieved. +// +// In contrast to sync.Pool, items are not removed automatically. Every item +// will be returned at some point. Items are returned in a FIFO manner in order +// to evenly distribute usage of a set of items. +// +// A StablePool is safe for use by multiple goroutines simultaneously and must +// not be copied after first use. +type StablePool struct { + lock sync.Mutex + + pool []interface{} + cnt int + getIndex int + putIndex int + + // New optionally specifies a function to generate + // a value when Get would otherwise return nil. + // It may not be changed concurrently with calls to Get. + New func() interface{} +} + +// Put adds x to the pool. +func (p *StablePool) Put(x interface{}) { + if x == nil { + return + } + + p.lock.Lock() + defer p.lock.Unlock() + + // check if pool is full (or unitialized) + if p.cnt == len(p.pool) { + p.pool = append(p.pool, x) + p.cnt++ + p.putIndex = p.cnt + return + } + + // correct putIndex + p.putIndex %= len(p.pool) + + // iterate the whole pool once to find a free spot + stopAt := p.putIndex - 1 + for i := p.putIndex; i != stopAt; i = (i + 1) % len(p.pool) { + if p.pool[i] == nil { + p.pool[i] = x + p.cnt++ + p.putIndex = i + 1 + return + } + } +} + +// Get returns the next item from the Pool, removes it from the Pool, and +// returns it to the caller. +// In contrast to sync.Pool, Get never ignores the pool. +// Callers should not assume any relation between values passed to Put and +// the values returned by Get. +// +// If Get would otherwise return nil and p.New is non-nil, Get returns +// the result of calling p.New. +func (p *StablePool) Get() interface{} { + p.lock.Lock() + defer p.lock.Unlock() + + // check if pool is empty + if p.cnt == 0 { + if p.New != nil { + return p.New() + } + return nil + } + + // correct getIndex + p.getIndex %= len(p.pool) + + // iterate the whole pool to find an item + stopAt := p.getIndex - 1 + for i := p.getIndex; i != stopAt; i = (i + 1) % len(p.pool) { + if p.pool[i] != nil { + x := p.pool[i] + p.pool[i] = nil + p.cnt-- + p.getIndex = i + 1 + return x + } + } + + // if we ever get here, return a new item + if p.New != nil { + return p.New() + } + return nil +} + +// Size returns the amount of items the pool currently holds. +func (p *StablePool) Size() int { + p.lock.Lock() + defer p.lock.Unlock() + + return p.cnt +} + +// Max returns the amount of items the pool held at maximum. 
+func (p *StablePool) Max() int { + p.lock.Lock() + defer p.lock.Unlock() + + return len(p.pool) +} diff --git a/base/utils/stablepool_test.go b/base/utils/stablepool_test.go new file mode 100644 index 000000000..0f8ed7262 --- /dev/null +++ b/base/utils/stablepool_test.go @@ -0,0 +1,120 @@ +package utils + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStablePoolRealWorld(t *testing.T) { + t.Parallel() + // "real world" simulation + + cnt := 0 + testPool := &StablePool{ + New: func() interface{} { + cnt++ + return cnt + }, + } + var testWg sync.WaitGroup + var testWorkerWg sync.WaitGroup + + // for i := 0; i < 100; i++ { + // cnt++ + // testPool.Put(cnt) + // } + for i := 0; i < 100; i++ { + // block round + testWg.Add(1) + // add workers + testWorkerWg.Add(100) + for j := 0; j < 100; j++ { + k := j + go func() { + // wait for round to start + testWg.Wait() + // get value + x := testPool.Get() + // fmt.Println(x) + // "work" + time.Sleep(5 * time.Microsecond) + // re-insert 99% + if k%100 > 0 { + testPool.Put(x) + } + // mark as finished + testWorkerWg.Done() + }() + } + // start round + testWg.Done() + // wait for round to finish + testWorkerWg.Wait() + } + t.Logf("real world simulation: cnt=%d p.cnt=%d p.max=%d\n", cnt, testPool.Size(), testPool.Max()) + assert.GreaterOrEqual(t, 200, cnt, "should not use more than 200 values") + assert.GreaterOrEqual(t, 100, testPool.Max(), "pool should have at most this max size") + + // optimal usage test + + optPool := &StablePool{} + for i := 0; i < 1000; i++ { + for j := 0; j < 100; j++ { + optPool.Put(j) + } + for k := 0; k < 100; k++ { + assert.Equal(t, k, optPool.Get(), "should match") + } + } + assert.Equal(t, 100, optPool.Max(), "pool should have exactly this max size") +} + +func TestStablePoolFuzzing(t *testing.T) { + t.Parallel() + // fuzzing test + + fuzzPool := &StablePool{} + var fuzzWg sync.WaitGroup + var fuzzWorkerWg sync.WaitGroup + // start goroutines and wait + fuzzWg.Add(1) + for i := 0; i < 1000; i++ { + fuzzWorkerWg.Add(2) + j := i + go func() { + fuzzWg.Wait() + fuzzPool.Put(j) + fuzzWorkerWg.Done() + }() + go func() { + fuzzWg.Wait() + fmt.Print(fuzzPool.Get()) + fuzzWorkerWg.Done() + }() + } + // kick off + fuzzWg.Done() + // wait for all to finish + fuzzWorkerWg.Wait() +} + +func TestStablePoolBreaking(t *testing.T) { + t.Parallel() + // try to break it + + breakPool := &StablePool{} + for i := 0; i < 10; i++ { + for j := 0; j < 100; j++ { + breakPool.Put(nil) + breakPool.Put(j) + breakPool.Put(nil) + } + for k := 0; k < 100; k++ { + assert.Equal(t, k, breakPool.Get(), "should match") + } + } +} diff --git a/base/utils/structure.go b/base/utils/structure.go new file mode 100644 index 000000000..5a50d97e0 --- /dev/null +++ b/base/utils/structure.go @@ -0,0 +1,139 @@ +package utils + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" +) + +// DirStructure represents a directory structure with permissions that should be enforced. +type DirStructure struct { + sync.Mutex + + Path string + Dir string + Perm os.FileMode + Parent *DirStructure + Children map[string]*DirStructure +} + +// NewDirStructure returns a new DirStructure. +func NewDirStructure(path string, perm os.FileMode) *DirStructure { + return &DirStructure{ + Path: path, + Perm: perm, + Children: make(map[string]*DirStructure), + } +} + +// ChildDir adds a new child DirStructure and returns it. Should the child already exist, the existing child is returned and the permissions are updated. 
+func (ds *DirStructure) ChildDir(dirName string, perm os.FileMode) (child *DirStructure) { + ds.Lock() + defer ds.Unlock() + + // if exists, update + child, ok := ds.Children[dirName] + if ok { + child.Perm = perm + return child + } + + // create new + newDir := &DirStructure{ + Path: filepath.Join(ds.Path, dirName), + Dir: dirName, + Perm: perm, + Parent: ds, + Children: make(map[string]*DirStructure), + } + ds.Children[dirName] = newDir + return newDir +} + +// Ensure ensures that the specified directory structure (from the first parent on) exists. +func (ds *DirStructure) Ensure() error { + return ds.EnsureAbsPath(ds.Path) +} + +// EnsureRelPath ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists. +func (ds *DirStructure) EnsureRelPath(dirPath string) error { + return ds.EnsureAbsPath(filepath.Join(ds.Path, dirPath)) +} + +// EnsureRelDir ensures that the specified directory structure (from the first parent on) and the given relative path (to the DirStructure) exists. +func (ds *DirStructure) EnsureRelDir(dirNames ...string) error { + return ds.EnsureAbsPath(filepath.Join(append([]string{ds.Path}, dirNames...)...)) +} + +// EnsureAbsPath ensures that the specified directory structure (from the first parent on) and the given absolute path exists. +// If the given path is outside the DirStructure, an error will be returned. +func (ds *DirStructure) EnsureAbsPath(dirPath string) error { + // always start at the top + if ds.Parent != nil { + return ds.Parent.EnsureAbsPath(dirPath) + } + + // check if root + if dirPath == ds.Path { + return ds.ensure(nil) + } + + // check scope + slashedPath := ds.Path + // add slash to end + if !strings.HasSuffix(slashedPath, string(filepath.Separator)) { + slashedPath += string(filepath.Separator) + } + // check if given path is in scope + if !strings.HasPrefix(dirPath, slashedPath) { + return fmt.Errorf(`path "%s" is outside of DirStructure scope`, dirPath) + } + + // get relative path + relPath, err := filepath.Rel(ds.Path, dirPath) + if err != nil { + return fmt.Errorf("failed to get relative path: %w", err) + } + + // split to path elements + pathDirs := strings.Split(filepath.ToSlash(relPath), "/") + + // start checking + return ds.ensure(pathDirs) +} + +func (ds *DirStructure) ensure(pathDirs []string) error { + ds.Lock() + defer ds.Unlock() + + // check current dir + err := EnsureDirectory(ds.Path, ds.Perm) + if err != nil { + return err + } + + if len(pathDirs) == 0 { + // we reached the end! 
+ return nil + } + + child, ok := ds.Children[pathDirs[0]] + if !ok { + // we have reached the end of the defined dir structure + // ensure all remaining dirs + dirPath := ds.Path + for _, dir := range pathDirs { + dirPath = filepath.Join(dirPath, dir) + err := EnsureDirectory(dirPath, ds.Perm) + if err != nil { + return err + } + } + return nil + } + + // we got a child, continue + return child.ensure(pathDirs[1:]) +} diff --git a/base/utils/structure_test.go b/base/utils/structure_test.go new file mode 100644 index 000000000..2acfebd21 --- /dev/null +++ b/base/utils/structure_test.go @@ -0,0 +1,73 @@ +//go:build !windows + +package utils + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +func ExampleDirStructure() { + // output: + // / [755] + // /repo [777] + // /repo/b [707] + // /repo/b/c [750] + // /repo/b/d [707] + // /repo/b/d/e [707] + // /repo/b/d/f [707] + // /repo/b/d/f/g [707] + // /repo/b/d/f/g/h [707] + // /secret [700] + + basePath, err := os.MkdirTemp("", "") + if err != nil { + fmt.Println(err) + return + } + + ds := NewDirStructure(basePath, 0o0755) + secret := ds.ChildDir("secret", 0o0700) + repo := ds.ChildDir("repo", 0o0777) + _ = repo.ChildDir("a", 0o0700) + b := repo.ChildDir("b", 0o0707) + c := b.ChildDir("c", 0o0750) + + err = ds.Ensure() + if err != nil { + fmt.Println(err) + } + + err = c.Ensure() + if err != nil { + fmt.Println(err) + } + + err = secret.Ensure() + if err != nil { + fmt.Println(err) + } + + err = b.EnsureRelDir("d", "e") + if err != nil { + fmt.Println(err) + } + + err = b.EnsureRelPath("d/f/g/h") + if err != nil { + fmt.Println(err) + } + + _ = filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error { + if err == nil { + dir := strings.TrimPrefix(path, basePath) + if dir == "" { + dir = "/" + } + fmt.Printf("%s [%o]\n", dir, info.Mode().Perm()) + } + return nil + }) +} diff --git a/base/utils/uuid.go b/base/utils/uuid.go new file mode 100644 index 000000000..bf437ba41 --- /dev/null +++ b/base/utils/uuid.go @@ -0,0 +1,45 @@ +package utils + +import ( + "encoding/binary" + "time" + + "github.com/gofrs/uuid" +) + +var ( + constantUUID = uuid.Must(uuid.FromString("e8dba9f7-21e2-4c82-96cb-6586922c6422")) + instanceUUID = RandomUUID("instance") +) + +// RandomUUID returns a new random UUID with optionally provided ns. +func RandomUUID(ns string) uuid.UUID { + randUUID, err := uuid.NewV4() + switch { + case err != nil: + // fallback + // should practically never happen + return uuid.NewV5(uuidFromTime(), ns) + case ns != "": + // mix ns into the UUID + return uuid.NewV5(randUUID, ns) + default: + return randUUID + } +} + +// DerivedUUID returns a new UUID that is derived from the input only, and therefore is always reproducible. +func DerivedUUID(input string) uuid.UUID { + return uuid.NewV5(constantUUID, input) +} + +// DerivedInstanceUUID returns a new UUID that is derived from the input, but is unique per instance (execution) and therefore is only reproducible with the same process. 
+func DerivedInstanceUUID(input string) uuid.UUID { + return uuid.NewV5(instanceUUID, input) +} + +func uuidFromTime() uuid.UUID { + var timeUUID uuid.UUID + binary.LittleEndian.PutUint64(timeUUID[:], uint64(time.Now().UnixNano())) + return timeUUID +} diff --git a/base/utils/uuid_test.go b/base/utils/uuid_test.go new file mode 100644 index 000000000..6985b81eb --- /dev/null +++ b/base/utils/uuid_test.go @@ -0,0 +1,71 @@ +package utils + +import ( + "testing" + "time" + + "github.com/gofrs/uuid" +) + +func TestUUID(t *testing.T) { + t.Parallel() + + // check randomness + a := RandomUUID("") + a2 := RandomUUID("") + if a.String() == a2.String() { + t.Error("should not match") + } + + // check with input + b := RandomUUID("b") + b2 := RandomUUID("b") + if b.String() == b2.String() { + t.Error("should not match") + } + + // check with long input + c := RandomUUID("TG8UkxS+4rVrDxHtDAaNab1CBpygzmX1g5mJA37jbQ5q2uE4rVrDxHtDAaNab1CBpygzmX1g5mJA37jbQ5q2uE") + c2 := RandomUUID("TG8UkxS+4rVrDxHtDAaNab1CBpygzmX1g5mJA37jbQ5q2uE4rVrDxHtDAaNab1CBpygzmX1g5mJA37jbQ5q2uE") + if c.String() == c2.String() { + t.Error("should not match") + } + + // check for nanosecond precision + d := uuidFromTime() + time.Sleep(2 * time.Nanosecond) + d2 := uuidFromTime() + if d.String() == d2.String() { + t.Error("should not match") + } + + // check mixing + timeUUID := uuidFromTime() + e := uuid.NewV5(timeUUID, "e") + e2 := uuid.NewV5(timeUUID, "e2") + if e.String() == e2.String() { + t.Error("should not match") + } + + // check deriving + f := DerivedUUID("f") + f2 := DerivedUUID("f") + f3 := DerivedUUID("f3") + if f.String() != f2.String() { + t.Error("should match") + } + if f.String() == f3.String() { + t.Error("should not match") + } + + // check instance deriving + g := DerivedInstanceUUID("g") + g2 := DerivedInstanceUUID("g") + g3 := DerivedInstanceUUID("g3") + if g.String() != g2.String() { + t.Error("should match") + } + if g.String() == g3.String() { + t.Error("should not match") + } +} diff --git a/cmds/hub/build b/cmds/hub/build index 055874ef0..c95f6e738 100755 --- a/cmds/hub/build +++ b/cmds/hub/build @@ -56,5 +56,5 @@ echo "This information is useful for debugging and license compliance." echo "Run the compiled binary with the -version flag to see the information included." 
# build -BUILD_PATH="github.com/safing/portbase/info" +BUILD_PATH="github.com/safing/portmaster/base/info" go build $DEV -ldflags "-X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $* diff --git a/cmds/hub/main.go b/cmds/hub/main.go index 74c7e3162..4b67299f7 100644 --- a/cmds/hub/main.go +++ b/cmds/hub/main.go @@ -6,10 +6,10 @@ import ( "os" "runtime" - "github.com/safing/portbase/info" - "github.com/safing/portbase/metrics" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/run" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/run" _ "github.com/safing/portmaster/service/core/base" _ "github.com/safing/portmaster/service/ui" "github.com/safing/portmaster/service/updates" diff --git a/cmds/notifier/http_api.go b/cmds/notifier/http_api.go index 356b81cd4..7a68349a9 100644 --- a/cmds/notifier/http_api.go +++ b/cmds/notifier/http_api.go @@ -8,7 +8,7 @@ import ( "strings" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) const ( diff --git a/cmds/notifier/main.go b/cmds/notifier/main.go index a109d01aa..ef4f0e603 100644 --- a/cmds/notifier/main.go +++ b/cmds/notifier/main.go @@ -17,13 +17,13 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/api/client" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/info" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/updater" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/api/client" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/cmds/notifier/notification.go b/cmds/notifier/notification.go index 075dba83f..0f6ded8da 100644 --- a/cmds/notifier/notification.go +++ b/cmds/notifier/notification.go @@ -3,7 +3,7 @@ package main import ( "fmt" - pbnotify "github.com/safing/portbase/notifications" + pbnotify "github.com/safing/portmaster/base/notifications" ) // Notification represents a notification that is to be delivered to the user. 
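The remaining hunks are largely limited to the same mechanical pattern: every `github.com/safing/portbase/...` import is rewritten to `github.com/safing/portmaster/base/...`, with the code itself unchanged. As a small illustration (the calling package and log line are hypothetical):

```go
package main

import (
	// Previously imported as "github.com/safing/portbase/log".
	"github.com/safing/portmaster/base/log"
)

func main() {
	log.Infof("portbase packages now live under %s", "github.com/safing/portmaster/base")
}
```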
diff --git a/cmds/notifier/notify.go b/cmds/notifier/notify.go index 2286dff61..d78c36109 100644 --- a/cmds/notifier/notify.go +++ b/cmds/notifier/notify.go @@ -6,10 +6,10 @@ import ( "sync" "time" - "github.com/safing/portbase/api/client" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" - pbnotify "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/api/client" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" + pbnotify "github.com/safing/portmaster/base/notifications" ) const ( diff --git a/cmds/notifier/notify_linux.go b/cmds/notifier/notify_linux.go index ba3f638e5..80cc8e15f 100644 --- a/cmds/notifier/notify_linux.go +++ b/cmds/notifier/notify_linux.go @@ -7,7 +7,7 @@ import ( notify "github.com/dhaavi/go-notify" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) type NotificationID uint32 diff --git a/cmds/notifier/notify_windows.go b/cmds/notifier/notify_windows.go index abb56be0a..98cf987af 100644 --- a/cmds/notifier/notify_windows.go +++ b/cmds/notifier/notify_windows.go @@ -4,7 +4,7 @@ import ( "fmt" "sync" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/cmds/notifier/wintoast" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/cmds/notifier/shutdown.go b/cmds/notifier/shutdown.go index f943938db..70b2e6d84 100644 --- a/cmds/notifier/shutdown.go +++ b/cmds/notifier/shutdown.go @@ -1,8 +1,8 @@ package main import ( - "github.com/safing/portbase/api/client" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api/client" + "github.com/safing/portmaster/base/log" ) func startShutdownEventListener() { diff --git a/cmds/notifier/spn.go b/cmds/notifier/spn.go index d313716bf..30fa18f81 100644 --- a/cmds/notifier/spn.go +++ b/cmds/notifier/spn.go @@ -6,9 +6,9 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/api/client" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api/client" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" ) const ( diff --git a/cmds/notifier/subsystems.go b/cmds/notifier/subsystems.go index 8444bf13d..38390cfd2 100644 --- a/cmds/notifier/subsystems.go +++ b/cmds/notifier/subsystems.go @@ -3,9 +3,9 @@ package main import ( "sync" - "github.com/safing/portbase/api/client" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api/client" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" ) const ( diff --git a/cmds/notifier/tray.go b/cmds/notifier/tray.go index 4044d4f74..abdf48d5e 100644 --- a/cmds/notifier/tray.go +++ b/cmds/notifier/tray.go @@ -11,8 +11,8 @@ import ( "fyne.io/systray" - "github.com/safing/portbase/log" icons "github.com/safing/portmaster/assets" + "github.com/safing/portmaster/base/log" ) const ( diff --git a/cmds/observation-hub/apprise.go b/cmds/observation-hub/apprise.go index c7df3c19e..bb16685cd 100644 --- a/cmds/observation-hub/apprise.go +++ b/cmds/observation-hub/apprise.go @@ -12,9 +12,9 @@ import ( "text/template" "time" - "github.com/safing/portbase/apprise" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/apprise" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" 
"github.com/safing/portmaster/service/intel/geoip" ) diff --git a/cmds/observation-hub/build b/cmds/observation-hub/build index 055874ef0..c95f6e738 100755 --- a/cmds/observation-hub/build +++ b/cmds/observation-hub/build @@ -56,5 +56,5 @@ echo "This information is useful for debugging and license compliance." echo "Run the compiled binary with the -version flag to see the information included." # build -BUILD_PATH="github.com/safing/portbase/info" +BUILD_PATH="github.com/safing/portmaster/base/info" go build $DEV -ldflags "-X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" $* diff --git a/cmds/observation-hub/main.go b/cmds/observation-hub/main.go index c69786c91..dfa7c582e 100644 --- a/cmds/observation-hub/main.go +++ b/cmds/observation-hub/main.go @@ -5,11 +5,11 @@ import ( "os" "runtime" - "github.com/safing/portbase/api" - "github.com/safing/portbase/info" - "github.com/safing/portbase/metrics" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/run" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/run" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/service/updates/helper" "github.com/safing/portmaster/spn/captain" diff --git a/cmds/observation-hub/observe.go b/cmds/observation-hub/observe.go index 371b86923..ca4e64038 100644 --- a/cmds/observation-hub/observe.go +++ b/cmds/observation-hub/observe.go @@ -12,10 +12,10 @@ import ( diff "github.com/r3labs/diff/v3" "golang.org/x/exp/slices" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/captain" "github.com/safing/portmaster/spn/navigator" ) diff --git a/cmds/portmaster-core/build b/cmds/portmaster-core/build index 6f6bb113c..63b410795 100755 --- a/cmds/portmaster-core/build +++ b/cmds/portmaster-core/build @@ -10,5 +10,5 @@ BUILD_TIME=$(date -u "+%Y-%m-%dT%H:%M:%SZ" || echo "unknown") # Build export CGO_ENABLED=0 -BUILD_PATH="github.com/safing/portbase/info" -go build -ldflags "-X github.com/safing/portbase/info.version=${VERSION} -X github.com/safing/portbase/info.buildSource=${SOURCE} -X github.com/safing/portbase/info.buildTime=${BUILD_TIME}" "$@" +BUILD_PATH="github.com/safing/portmaster/base/info" +go build -ldflags "-X github.com/safing/portmaster/base/info.version=${VERSION} -X github.com/safing/portmaster/base/info.buildSource=${SOURCE} -X github.com/safing/portmaster/base/info.buildTime=${BUILD_TIME}" "$@" diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index 687dcf4e2..764e8ef33 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -6,15 +6,15 @@ import ( "os" "runtime" - "github.com/safing/portbase/info" - "github.com/safing/portbase/log" - "github.com/safing/portbase/metrics" - "github.com/safing/portbase/run" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" + 
"github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/run" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/spn/conf" // Include packages here. - _ "github.com/safing/portbase/modules/subsystems" + _ "github.com/safing/portmaster/base/modules/subsystems" _ "github.com/safing/portmaster/service/core" _ "github.com/safing/portmaster/service/firewall" _ "github.com/safing/portmaster/service/nameserver" diff --git a/cmds/portmaster-start/build b/cmds/portmaster-start/build index a95b04df2..38c552f5f 100755 --- a/cmds/portmaster-start/build +++ b/cmds/portmaster-start/build @@ -73,5 +73,5 @@ echo "This information is useful for debugging and license compliance." echo "Run the compiled binary with the -version flag to see the information included." # build -BUILD_PATH="github.com/safing/portbase/info" +BUILD_PATH="github.com/safing/portmaster/base/info" go build -ldflags "$EXTRA_LD_FLAGS -X ${BUILD_PATH}.commit=${BUILD_COMMIT} -X ${BUILD_PATH}.buildOptions=${BUILD_BUILDOPTIONS} -X ${BUILD_PATH}.buildUser=${BUILD_USER} -X ${BUILD_PATH}.buildHost=${BUILD_HOST} -X ${BUILD_PATH}.buildDate=${BUILD_DATE} -X ${BUILD_PATH}.buildSource=${BUILD_SOURCE}" "$@" diff --git a/cmds/portmaster-start/logs.go b/cmds/portmaster-start/logs.go index a7ab4d020..d9b02c616 100644 --- a/cmds/portmaster-start/logs.go +++ b/cmds/portmaster-start/logs.go @@ -10,10 +10,10 @@ import ( "github.com/spf13/cobra" - "github.com/safing/portbase/container" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/info" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/info" ) func initializeLogFile(logFilePath string, identifier string, version string) *os.File { diff --git a/cmds/portmaster-start/main.go b/cmds/portmaster-start/main.go index fd2e5a328..1762f48b5 100644 --- a/cmds/portmaster-start/main.go +++ b/cmds/portmaster-start/main.go @@ -15,11 +15,11 @@ import ( "github.com/spf13/cobra" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/info" - portlog "github.com/safing/portbase/log" - "github.com/safing/portbase/updater" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/info" + portlog "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/cmds/portmaster-start/update.go b/cmds/portmaster-start/update.go index bbcec8604..544047e1a 100644 --- a/cmds/portmaster-start/update.go +++ b/cmds/portmaster-start/update.go @@ -8,8 +8,8 @@ import ( "github.com/spf13/cobra" - portlog "github.com/safing/portbase/log" - "github.com/safing/portbase/updater" + portlog "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/cmds/portmaster-start/verify.go b/cmds/portmaster-start/verify.go index 7fb7be08a..9d63c51a9 100644 --- a/cmds/portmaster-start/verify.go +++ b/cmds/portmaster-start/verify.go @@ -13,8 +13,8 @@ import ( "github.com/safing/jess" "github.com/safing/jess/filesig" - portlog "github.com/safing/portbase/log" - "github.com/safing/portbase/updater" + portlog "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" 
"github.com/safing/portmaster/service/updates/helper" ) diff --git a/cmds/portmaster-start/version.go b/cmds/portmaster-start/version.go index fa82884b7..6c28362aa 100644 --- a/cmds/portmaster-start/version.go +++ b/cmds/portmaster-start/version.go @@ -10,7 +10,7 @@ import ( "github.com/spf13/cobra" - "github.com/safing/portbase/info" + "github.com/safing/portmaster/base/info" ) var ( diff --git a/cmds/testsuite/db.go b/cmds/testsuite/db.go index 848e4d891..b23a6fd72 100644 --- a/cmds/testsuite/db.go +++ b/cmds/testsuite/db.go @@ -1,8 +1,8 @@ package main import ( - "github.com/safing/portbase/database" - _ "github.com/safing/portbase/database/storage/hashmap" + "github.com/safing/portmaster/base/database" + _ "github.com/safing/portmaster/base/database/storage/hashmap" ) func setupDatabases(path string) error { diff --git a/cmds/trafficgen/main.go b/cmds/trafficgen/main.go index e57b6d1da..efcfd390a 100644 --- a/cmds/trafficgen/main.go +++ b/cmds/trafficgen/main.go @@ -9,7 +9,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) const dnsResolver = "1.1.1.1:53" diff --git a/cmds/updatemgr/main.go b/cmds/updatemgr/main.go index ffd66b282..acd9a0d47 100644 --- a/cmds/updatemgr/main.go +++ b/cmds/updatemgr/main.go @@ -7,8 +7,8 @@ import ( "github.com/spf13/cobra" - "github.com/safing/portbase/updater" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/base/utils" ) var ( diff --git a/cmds/updatemgr/purge.go b/cmds/updatemgr/purge.go index cfdf3711f..d5b456ee9 100644 --- a/cmds/updatemgr/purge.go +++ b/cmds/updatemgr/purge.go @@ -5,7 +5,7 @@ import ( "github.com/spf13/cobra" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) func init() { diff --git a/cmds/updatemgr/release.go b/cmds/updatemgr/release.go index e32b020c2..0f5d596e7 100644 --- a/cmds/updatemgr/release.go +++ b/cmds/updatemgr/release.go @@ -10,7 +10,7 @@ import ( "github.com/spf13/cobra" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/updater" ) var ( diff --git a/cmds/winkext-test/main.go b/cmds/winkext-test/main.go index 9b17b1a33..f67382b55 100644 --- a/cmds/winkext-test/main.go +++ b/cmds/winkext-test/main.go @@ -12,7 +12,7 @@ import ( "path/filepath" "syscall" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/firewall/interception/windowskext" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/broadcasts/api.go b/service/broadcasts/api.go index f855ddfc3..4bee5195c 100644 --- a/service/broadcasts/api.go +++ b/service/broadcasts/api.go @@ -7,9 +7,9 @@ import ( "net/http" "strings" - "github.com/safing/portbase/api" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/accessor" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/accessor" ) func registerAPIEndpoints() error { diff --git a/service/broadcasts/data.go b/service/broadcasts/data.go index 22faf4581..2b59e4e63 100644 --- a/service/broadcasts/data.go +++ b/service/broadcasts/data.go @@ -4,7 +4,7 @@ import ( "strconv" "time" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/updates" diff --git 
a/service/broadcasts/install_info.go b/service/broadcasts/install_info.go index 2f667a173..429699463 100644 --- a/service/broadcasts/install_info.go +++ b/service/broadcasts/install_info.go @@ -9,11 +9,11 @@ import ( semver "github.com/hashicorp/go-version" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/info" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" ) const installInfoDBKey = "core:status/install-info" diff --git a/service/broadcasts/module.go b/service/broadcasts/module.go index 360bc912b..9741d7e14 100644 --- a/service/broadcasts/module.go +++ b/service/broadcasts/module.go @@ -4,8 +4,8 @@ import ( "sync" "time" - "github.com/safing/portbase/database" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/modules" ) var ( diff --git a/service/broadcasts/notify.go b/service/broadcasts/notify.go index 4e359139d..57a7830ad 100644 --- a/service/broadcasts/notify.go +++ b/service/broadcasts/notify.go @@ -12,12 +12,12 @@ import ( "github.com/ghodss/yaml" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/accessor" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/accessor" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/updates" ) diff --git a/service/broadcasts/state.go b/service/broadcasts/state.go index afe8994cd..6669f20c0 100644 --- a/service/broadcasts/state.go +++ b/service/broadcasts/state.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database/record" ) const broadcastStatesDBKey = "core:broadcasts/state" diff --git a/service/compat/api.go b/service/compat/api.go index 78365c057..998475fac 100644 --- a/service/compat/api.go +++ b/service/compat/api.go @@ -1,7 +1,7 @@ package compat import ( - "github.com/safing/portbase/api" + "github.com/safing/portmaster/base/api" ) func registerAPIEndpoints() error { diff --git a/service/compat/debug_default.go b/service/compat/debug_default.go index 82ac6bc32..478dea03d 100644 --- a/service/compat/debug_default.go +++ b/service/compat/debug_default.go @@ -2,7 +2,7 @@ package compat -import "github.com/safing/portbase/utils/debug" +import "github.com/safing/portmaster/base/utils/debug" // AddToDebugInfo adds compatibility data to the given debug.Info. func AddToDebugInfo(di *debug.Info) { diff --git a/service/compat/debug_linux.go b/service/compat/debug_linux.go index 825145098..2cabe95a1 100644 --- a/service/compat/debug_linux.go +++ b/service/compat/debug_linux.go @@ -3,7 +3,7 @@ package compat import ( "fmt" - "github.com/safing/portbase/utils/debug" + "github.com/safing/portmaster/base/utils/debug" ) // AddToDebugInfo adds compatibility data to the given debug.Info. 
diff --git a/service/compat/debug_windows.go b/service/compat/debug_windows.go index 176397f5b..44bdb1095 100644 --- a/service/compat/debug_windows.go +++ b/service/compat/debug_windows.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/safing/portbase/utils/debug" + "github.com/safing/portmaster/base/utils/debug" ) // AddToDebugInfo adds compatibility data to the given debug.Info. diff --git a/service/compat/module.go b/service/compat/module.go index b8b95090c..0d9c3bd91 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -7,8 +7,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/resolver" ) diff --git a/service/compat/notify.go b/service/compat/notify.go index f26f0ea3e..38db2168b 100644 --- a/service/compat/notify.go +++ b/service/compat/notify.go @@ -8,10 +8,10 @@ import ( "sync" "time" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/profile" ) diff --git a/service/compat/selfcheck.go b/service/compat/selfcheck.go index f4775cdc6..872053c8a 100644 --- a/service/compat/selfcheck.go +++ b/service/compat/selfcheck.go @@ -10,8 +10,8 @@ import ( "sync" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/resolver" diff --git a/service/compat/wfpstate.go b/service/compat/wfpstate.go index 72f844c59..11a746b45 100644 --- a/service/compat/wfpstate.go +++ b/service/compat/wfpstate.go @@ -11,7 +11,7 @@ import ( "strings" "text/tabwriter" - "github.com/safing/portbase/utils/osdetail" + "github.com/safing/portmaster/base/utils/osdetail" ) // GetWFPState queries the system for the WFP state and returns a simplified diff --git a/service/core/api.go b/service/core/api.go index abc43dad9..bdacb6a14 100644 --- a/service/core/api.go +++ b/service/core/api.go @@ -9,13 +9,13 @@ import ( "net/url" "time" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/notifications" - "github.com/safing/portbase/rng" - "github.com/safing/portbase/utils/debug" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/base/utils/debug" "github.com/safing/portmaster/service/compat" "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/resolver" diff --git a/service/core/base/databases.go b/service/core/base/databases.go index 56355125f..c0e9dca4a 100644 --- a/service/core/base/databases.go +++ b/service/core/base/databases.go @@ -1,9 +1,9 @@ package base import 
( - "github.com/safing/portbase/database" - _ "github.com/safing/portbase/database/dbmodule" - _ "github.com/safing/portbase/database/storage/bbolt" + "github.com/safing/portmaster/base/database" + _ "github.com/safing/portmaster/base/database/dbmodule" + _ "github.com/safing/portmaster/base/database/storage/bbolt" ) // Default Values (changeable for testing). diff --git a/service/core/base/global.go b/service/core/base/global.go index 2de07ad21..a28827a39 100644 --- a/service/core/base/global.go +++ b/service/core/base/global.go @@ -5,10 +5,10 @@ import ( "flag" "fmt" - "github.com/safing/portbase/api" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/info" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/modules" ) // Default Values (changeable for testing). diff --git a/service/core/base/logs.go b/service/core/base/logs.go index dab2ebac3..91870e0e7 100644 --- a/service/core/base/logs.go +++ b/service/core/base/logs.go @@ -8,9 +8,9 @@ import ( "strings" "time" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" ) const ( diff --git a/service/core/base/module.go b/service/core/base/module.go index 10870d174..581e74b4b 100644 --- a/service/core/base/module.go +++ b/service/core/base/module.go @@ -1,10 +1,10 @@ package base import ( - _ "github.com/safing/portbase/config" - _ "github.com/safing/portbase/metrics" - "github.com/safing/portbase/modules" - _ "github.com/safing/portbase/rng" + _ "github.com/safing/portmaster/base/config" + _ "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/modules" + _ "github.com/safing/portmaster/base/rng" ) var module *modules.Module diff --git a/service/core/config.go b/service/core/config.go index 28fb24841..da7c8b9ca 100644 --- a/service/core/config.go +++ b/service/core/config.go @@ -6,8 +6,8 @@ import ( locale "github.com/Xuanwo/go-locale" "golang.org/x/exp/slices" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" ) // Configuration Keys. diff --git a/service/core/core.go b/service/core/core.go index ff5357594..fbeee9bc1 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -5,10 +5,10 @@ import ( "fmt" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/metrics" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/modules/subsystems" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/modules/subsystems" _ "github.com/safing/portmaster/service/broadcasts" _ "github.com/safing/portmaster/service/netenv" _ "github.com/safing/portmaster/service/netquery" diff --git a/service/core/os_windows.go b/service/core/os_windows.go index 41f1f5945..bd5313148 100644 --- a/service/core/os_windows.go +++ b/service/core/os_windows.go @@ -1,8 +1,8 @@ package core import ( - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils/osdetail" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils/osdetail" ) // only return on Fatal error! 
diff --git a/service/core/pmtesting/testing.go b/service/core/pmtesting/testing.go index 16253f865..131410610 100644 --- a/service/core/pmtesting/testing.go +++ b/service/core/pmtesting/testing.go @@ -23,10 +23,10 @@ import ( "runtime/pprof" "testing" - _ "github.com/safing/portbase/database/storage/hashmap" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + _ "github.com/safing/portmaster/base/database/storage/hashmap" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/core/base" ) diff --git a/service/firewall/api.go b/service/firewall/api.go index 24fa69bae..c37dbf690 100644 --- a/service/firewall/api.go +++ b/service/firewall/api.go @@ -10,10 +10,10 @@ import ( "strings" "time" - "github.com/safing/portbase/api" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/network/packet" diff --git a/service/firewall/config.go b/service/firewall/config.go index 960c000b5..34cbe4e83 100644 --- a/service/firewall/config.go +++ b/service/firewall/config.go @@ -3,9 +3,9 @@ package firewall import ( "github.com/tevino/abool" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/core" "github.com/safing/portmaster/spn/captain" ) diff --git a/service/firewall/dns.go b/service/firewall/dns.go index 3712165db..8a6e19738 100644 --- a/service/firewall/dns.go +++ b/service/firewall/dns.go @@ -8,8 +8,8 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/database" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/profile" diff --git a/service/firewall/interception/ebpf/bandwidth/interface.go b/service/firewall/interception/ebpf/bandwidth/interface.go index e1473dbd9..124b21a42 100644 --- a/service/firewall/interception/ebpf/bandwidth/interface.go +++ b/service/firewall/interception/ebpf/bandwidth/interface.go @@ -15,7 +15,7 @@ import ( "github.com/cilium/ebpf/rlimit" "golang.org/x/sys/unix" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/firewall/interception/ebpf/connection_listener/worker.go b/service/firewall/interception/ebpf/connection_listener/worker.go index aadfd57f8..e7019c3fd 100644 --- a/service/firewall/interception/ebpf/connection_listener/worker.go +++ b/service/firewall/interception/ebpf/connection_listener/worker.go @@ -14,7 +14,7 @@ import ( "github.com/cilium/ebpf/ringbuf" "github.com/cilium/ebpf/rlimit" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/packet" ) diff --git 
a/service/firewall/interception/ebpf/exec/exec.go b/service/firewall/interception/ebpf/exec/exec.go index f6dbb3834..3e5433442 100644 --- a/service/firewall/interception/ebpf/exec/exec.go +++ b/service/firewall/interception/ebpf/exec/exec.go @@ -17,7 +17,7 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/sys/unix" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) //go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang -cflags "-O2 -g -Wall -Werror" bpf ../programs/exec.c diff --git a/service/firewall/interception/interception_default.go b/service/firewall/interception/interception_default.go index a4a93f442..8755f225b 100644 --- a/service/firewall/interception/interception_default.go +++ b/service/firewall/interception/interception_default.go @@ -3,7 +3,7 @@ package interception import ( - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/firewall/interception/interception_windows.go b/service/firewall/interception/interception_windows.go index ab7535188..bd36ffa1a 100644 --- a/service/firewall/interception/interception_windows.go +++ b/service/firewall/interception/interception_windows.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" kext1 "github.com/safing/portmaster/service/firewall/interception/windowskext" kext2 "github.com/safing/portmaster/service/firewall/interception/windowskext2" "github.com/safing/portmaster/service/network" diff --git a/service/firewall/interception/introspection.go b/service/firewall/interception/introspection.go index c8361e1ca..088a33c8c 100644 --- a/service/firewall/interception/introspection.go +++ b/service/firewall/interception/introspection.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) var ( diff --git a/service/firewall/interception/module.go b/service/firewall/interception/module.go index 2802defa8..57c5a9611 100644 --- a/service/firewall/interception/module.go +++ b/service/firewall/interception/module.go @@ -3,8 +3,8 @@ package interception import ( "flag" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/firewall/interception/nfq/conntrack.go b/service/firewall/interception/nfq/conntrack.go index ea7761e4f..776f01176 100644 --- a/service/firewall/interception/nfq/conntrack.go +++ b/service/firewall/interception/nfq/conntrack.go @@ -9,7 +9,7 @@ import ( ct "github.com/florianl/go-conntrack" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network" ) diff --git a/service/firewall/interception/nfq/nfq.go b/service/firewall/interception/nfq/nfq.go index f75799209..22a5b3908 100644 --- a/service/firewall/interception/nfq/nfq.go +++ b/service/firewall/interception/nfq/nfq.go @@ -14,7 +14,7 @@ import ( "github.com/tevino/abool" "golang.org/x/sys/unix" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" pmpacket "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/process" ) diff --git a/service/firewall/interception/nfq/packet.go 
b/service/firewall/interception/nfq/packet.go index af3d5fac2..49ceda164 100644 --- a/service/firewall/interception/nfq/packet.go +++ b/service/firewall/interception/nfq/packet.go @@ -10,7 +10,7 @@ import ( "github.com/florianl/go-nfqueue" "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" pmpacket "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/firewall/interception/nfqueue_linux.go b/service/firewall/interception/nfqueue_linux.go index 537bbcb79..e4c136d9d 100644 --- a/service/firewall/interception/nfqueue_linux.go +++ b/service/firewall/interception/nfqueue_linux.go @@ -10,7 +10,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/hashicorp/go-multierror" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/firewall/interception/nfq" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/packet" diff --git a/service/firewall/interception/windowskext/bandwidth_stats.go b/service/firewall/interception/windowskext/bandwidth_stats.go index a29e50d99..68bf58055 100644 --- a/service/firewall/interception/windowskext/bandwidth_stats.go +++ b/service/firewall/interception/windowskext/bandwidth_stats.go @@ -9,7 +9,7 @@ import ( "context" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/firewall/interception/windowskext/handler.go b/service/firewall/interception/windowskext/handler.go index a5a8de74d..170ee77bc 100644 --- a/service/firewall/interception/windowskext/handler.go +++ b/service/firewall/interception/windowskext/handler.go @@ -16,7 +16,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/firewall/interception/windowskext/kext.go b/service/firewall/interception/windowskext/kext.go index 34badd6d7..f5ca9bc6a 100644 --- a/service/firewall/interception/windowskext/kext.go +++ b/service/firewall/interception/windowskext/kext.go @@ -10,7 +10,7 @@ import ( "syscall" "unsafe" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/packet" "golang.org/x/sys/windows" diff --git a/service/firewall/interception/windowskext/packet.go b/service/firewall/interception/windowskext/packet.go index 5f96e784f..5942d7d91 100644 --- a/service/firewall/interception/windowskext/packet.go +++ b/service/firewall/interception/windowskext/packet.go @@ -8,7 +8,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/firewall/interception/windowskext/service.go b/service/firewall/interception/windowskext/service.go index e3e4ac2a9..90b1e9729 100644 --- a/service/firewall/interception/windowskext/service.go +++ b/service/firewall/interception/windowskext/service.go @@ -8,7 +8,7 @@ import ( "syscall" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "golang.org/x/sys/windows" ) diff --git a/service/firewall/interception/windowskext2/handler.go b/service/firewall/interception/windowskext2/handler.go index 
bb6348dd2..3488d7ff1 100644 --- a/service/firewall/interception/windowskext2/handler.go +++ b/service/firewall/interception/windowskext2/handler.go @@ -13,7 +13,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/firewall/interception/windowskext2/kext.go b/service/firewall/interception/windowskext2/kext.go index fd6adb721..1c3a4d4df 100644 --- a/service/firewall/interception/windowskext2/kext.go +++ b/service/firewall/interception/windowskext2/kext.go @@ -6,7 +6,7 @@ package windowskext import ( "fmt" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/windows_kext/kextinterface" "golang.org/x/sys/windows" diff --git a/service/firewall/interception/windowskext2/packet.go b/service/firewall/interception/windowskext2/packet.go index 3ea9c0095..52a7a2a79 100644 --- a/service/firewall/interception/windowskext2/packet.go +++ b/service/firewall/interception/windowskext2/packet.go @@ -8,7 +8,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/windows_kext/kextinterface" ) diff --git a/service/firewall/master.go b/service/firewall/master.go index 6549194f7..95d22212f 100644 --- a/service/firewall/master.go +++ b/service/firewall/master.go @@ -11,7 +11,7 @@ import ( "github.com/agext/levenshtein" "golang.org/x/net/publicsuffix" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/detection/dga" "github.com/safing/portmaster/service/intel/customlists" "github.com/safing/portmaster/service/intel/filterlists" diff --git a/service/firewall/module.go b/service/firewall/module.go index 3b80fc889..9430b48af 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -7,10 +7,10 @@ import ( "path/filepath" "strings" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/modules/subsystems" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/modules/subsystems" _ "github.com/safing/portmaster/service/core" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/profile" diff --git a/service/firewall/packet_handler.go b/service/firewall/packet_handler.go index bfb8473c0..faf3dceec 100644 --- a/service/firewall/packet_handler.go +++ b/service/firewall/packet_handler.go @@ -12,7 +12,7 @@ import ( "github.com/google/gopacket/layers" "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/compat" _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/firewall/inspection" diff --git a/service/firewall/prompt.go b/service/firewall/prompt.go index 51d6a12a2..984479eed 100644 --- a/service/firewall/prompt.go +++ b/service/firewall/prompt.go @@ -6,8 +6,8 @@ import ( "sync" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/intel" 
"github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/profile" diff --git a/service/firewall/tunnel.go b/service/firewall/tunnel.go index 46b5864a8..f975a2b97 100644 --- a/service/firewall/tunnel.go +++ b/service/firewall/tunnel.go @@ -4,7 +4,7 @@ import ( "context" "errors" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network" diff --git a/service/intel/block_reason.go b/service/intel/block_reason.go index 5cabbddf8..65f92755e 100644 --- a/service/intel/block_reason.go +++ b/service/intel/block_reason.go @@ -8,7 +8,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/nameserver/nsutil" ) diff --git a/service/intel/customlists/config.go b/service/intel/customlists/config.go index 140b788ff..ad8d0ae55 100644 --- a/service/intel/customlists/config.go +++ b/service/intel/customlists/config.go @@ -1,7 +1,7 @@ package customlists import ( - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" ) var ( diff --git a/service/intel/customlists/lists.go b/service/intel/customlists/lists.go index 33170dd70..cf807248a 100644 --- a/service/intel/customlists/lists.go +++ b/service/intel/customlists/lists.go @@ -10,8 +10,8 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/network/netutils" ) diff --git a/service/intel/customlists/module.go b/service/intel/customlists/module.go index d33ac4f29..988617037 100644 --- a/service/intel/customlists/module.go +++ b/service/intel/customlists/module.go @@ -12,8 +12,8 @@ import ( "golang.org/x/net/publicsuffix" - "github.com/safing/portbase/api" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/modules" ) var module *modules.Module diff --git a/service/intel/entity.go b/service/intel/entity.go index df67edfcc..986ca1401 100644 --- a/service/intel/entity.go +++ b/service/intel/entity.go @@ -10,7 +10,7 @@ import ( "golang.org/x/net/publicsuffix" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel/filterlists" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/network/netutils" diff --git a/service/intel/filterlists/bloom.go b/service/intel/filterlists/bloom.go index 1b0fd99fa..77e7f2f4a 100644 --- a/service/intel/filterlists/bloom.go +++ b/service/intel/filterlists/bloom.go @@ -8,8 +8,8 @@ import ( "github.com/tannerryan/ring" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" ) var defaultFilter = newScopedBloom() diff --git a/service/intel/filterlists/cache_version.go b/service/intel/filterlists/cache_version.go index b1f40dfc2..c6c28db5a 100644 --- a/service/intel/filterlists/cache_version.go +++ b/service/intel/filterlists/cache_version.go @@ -6,8 +6,8 @@ import ( "github.com/hashicorp/go-version" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database" + 
"github.com/safing/portmaster/base/database/record" ) const resetVersion = "v0.6.0" diff --git a/service/intel/filterlists/database.go b/service/intel/filterlists/database.go index 8b08f3238..5f55323c6 100644 --- a/service/intel/filterlists/database.go +++ b/service/intel/filterlists/database.go @@ -11,10 +11,10 @@ import ( "golang.org/x/sync/errgroup" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates" ) diff --git a/service/intel/filterlists/decoder.go b/service/intel/filterlists/decoder.go index e66d8d840..4083a237f 100644 --- a/service/intel/filterlists/decoder.go +++ b/service/intel/filterlists/decoder.go @@ -8,8 +8,8 @@ import ( "fmt" "io" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/utils" ) type listEntry struct { diff --git a/service/intel/filterlists/index.go b/service/intel/filterlists/index.go index e5a593b62..74770fe27 100644 --- a/service/intel/filterlists/index.go +++ b/service/intel/filterlists/index.go @@ -7,11 +7,11 @@ import ( "strings" "sync" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates" ) diff --git a/service/intel/filterlists/lookup.go b/service/intel/filterlists/lookup.go index 8b4e0bd7d..2b3348165 100644 --- a/service/intel/filterlists/lookup.go +++ b/service/intel/filterlists/lookup.go @@ -4,8 +4,8 @@ import ( "errors" "net" - "github.com/safing/portbase/database" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/log" ) // lookupBlockLists loads the entity record for key from diff --git a/service/intel/filterlists/module.go b/service/intel/filterlists/module.go index a7846ee4a..0fa6aadf8 100644 --- a/service/intel/filterlists/module.go +++ b/service/intel/filterlists/module.go @@ -6,8 +6,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/updates" ) diff --git a/service/intel/filterlists/record.go b/service/intel/filterlists/record.go index 418790d6b..be2e7c5e3 100644 --- a/service/intel/filterlists/record.go +++ b/service/intel/filterlists/record.go @@ -4,7 +4,7 @@ import ( "fmt" "sync" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database/record" ) type entityRecord struct { diff --git a/service/intel/filterlists/updater.go b/service/intel/filterlists/updater.go index 7d15e85ec..cbe9fc3f8 100644 --- a/service/intel/filterlists/updater.go +++ b/service/intel/filterlists/updater.go @@ -10,11 +10,11 @@ import ( 
"github.com/hashicorp/go-version" "github.com/tevino/abool" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/updater" ) var updateInProgress = abool.New() diff --git a/service/intel/geoip/database.go b/service/intel/geoip/database.go index 57b085786..3101f7dc0 100644 --- a/service/intel/geoip/database.go +++ b/service/intel/geoip/database.go @@ -8,8 +8,8 @@ import ( maxminddb "github.com/oschwald/maxminddb-golang" - "github.com/safing/portbase/log" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates" ) diff --git a/service/intel/geoip/location.go b/service/intel/geoip/location.go index 5295584ee..c6ea7d9a2 100644 --- a/service/intel/geoip/location.go +++ b/service/intel/geoip/location.go @@ -7,7 +7,7 @@ import ( "github.com/umahmood/haversine" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/utils" ) const ( diff --git a/service/intel/geoip/module.go b/service/intel/geoip/module.go index c5d44e004..7141d476e 100644 --- a/service/intel/geoip/module.go +++ b/service/intel/geoip/module.go @@ -3,8 +3,8 @@ package geoip import ( "context" - "github.com/safing/portbase/api" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/updates" ) diff --git a/service/intel/geoip/regions.go b/service/intel/geoip/regions.go index 879da23d1..752fd1b21 100644 --- a/service/intel/geoip/regions.go +++ b/service/intel/geoip/regions.go @@ -1,7 +1,7 @@ package geoip import ( - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/utils" ) // IsRegionalNeighbor returns whether the supplied location is a regional neighbor. 
diff --git a/service/intel/geoip/regions_test.go b/service/intel/geoip/regions_test.go index 6ea7ae341..f7c59c044 100644 --- a/service/intel/geoip/regions_test.go +++ b/service/intel/geoip/regions_test.go @@ -3,7 +3,7 @@ package geoip import ( "testing" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/utils" ) func TestRegions(t *testing.T) { diff --git a/service/intel/module.go b/service/intel/module.go index 35c2d75c1..87f872b51 100644 --- a/service/intel/module.go +++ b/service/intel/module.go @@ -1,7 +1,7 @@ package intel import ( - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" _ "github.com/safing/portmaster/service/intel/customlists" ) diff --git a/service/nameserver/config.go b/service/nameserver/config.go index 3e13044a5..3affc7cfc 100644 --- a/service/nameserver/config.go +++ b/service/nameserver/config.go @@ -4,7 +4,7 @@ import ( "flag" "runtime" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/core" ) diff --git a/service/nameserver/conflict.go b/service/nameserver/conflict.go index f716f7ebf..3837a938e 100644 --- a/service/nameserver/conflict.go +++ b/service/nameserver/conflict.go @@ -6,7 +6,7 @@ import ( processInfo "github.com/shirou/gopsutil/process" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/network/state" ) diff --git a/service/nameserver/metrics.go b/service/nameserver/metrics.go index eca11bd2b..204571282 100644 --- a/service/nameserver/metrics.go +++ b/service/nameserver/metrics.go @@ -1,9 +1,9 @@ package nameserver import ( - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/metrics" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/metrics" ) var ( diff --git a/service/nameserver/module.go b/service/nameserver/module.go index 8380583b1..6dcd320dd 100644 --- a/service/nameserver/module.go +++ b/service/nameserver/module.go @@ -10,10 +10,10 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/modules/subsystems" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/modules/subsystems" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/compat" "github.com/safing/portmaster/service/firewall" "github.com/safing/portmaster/service/netenv" diff --git a/service/nameserver/nameserver.go b/service/nameserver/nameserver.go index 66bccd8e3..903541f5f 100644 --- a/service/nameserver/nameserver.go +++ b/service/nameserver/nameserver.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/firewall" "github.com/safing/portmaster/service/nameserver/nsutil" "github.com/safing/portmaster/service/netenv" diff --git a/service/nameserver/nsutil/nsutil.go b/service/nameserver/nsutil/nsutil.go index 7c440c11f..0abd629b0 100644 --- a/service/nameserver/nsutil/nsutil.go +++ b/service/nameserver/nsutil/nsutil.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) // ErrNilRR is returned when a 
parsed RR is nil. diff --git a/service/nameserver/response.go b/service/nameserver/response.go index 85daf1409..c7fe196e7 100644 --- a/service/nameserver/response.go +++ b/service/nameserver/response.go @@ -6,7 +6,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/nameserver/nsutil" ) diff --git a/service/netenv/adresses.go b/service/netenv/adresses.go index 902dd0da2..da4895e64 100644 --- a/service/netenv/adresses.go +++ b/service/netenv/adresses.go @@ -6,7 +6,7 @@ import ( "sync" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/netutils" ) diff --git a/service/netenv/api.go b/service/netenv/api.go index 20a2f6883..e5dca1870 100644 --- a/service/netenv/api.go +++ b/service/netenv/api.go @@ -3,7 +3,7 @@ package netenv import ( "errors" - "github.com/safing/portbase/api" + "github.com/safing/portmaster/base/api" ) func registerAPIEndpoints() error { diff --git a/service/netenv/dbus_linux.go b/service/netenv/dbus_linux.go index 730cb5641..03483de40 100644 --- a/service/netenv/dbus_linux.go +++ b/service/netenv/dbus_linux.go @@ -10,7 +10,7 @@ import ( "github.com/godbus/dbus/v5" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) var ( diff --git a/service/netenv/environment_linux.go b/service/netenv/environment_linux.go index 5f39875b0..a7e47e5b4 100644 --- a/service/netenv/environment_linux.go +++ b/service/netenv/environment_linux.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/netutils" ) diff --git a/service/netenv/environment_windows.go b/service/netenv/environment_windows.go index 21cd6f2b4..90f778e92 100644 --- a/service/netenv/environment_windows.go +++ b/service/netenv/environment_windows.go @@ -10,8 +10,8 @@ import ( "sync" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils/osdetail" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils/osdetail" ) // Gateways returns the currently active gateways. 
diff --git a/service/netenv/icmp_listener.go b/service/netenv/icmp_listener.go index d1716d8a8..e5d40dd50 100644 --- a/service/netenv/icmp_listener.go +++ b/service/netenv/icmp_listener.go @@ -6,7 +6,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/packet" ) diff --git a/service/netenv/location.go b/service/netenv/location.go index 276e33a32..6026da33c 100644 --- a/service/netenv/location.go +++ b/service/netenv/location.go @@ -12,8 +12,8 @@ import ( "golang.org/x/net/icmp" "golang.org/x/net/ipv4" - "github.com/safing/portbase/log" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/network/packet" diff --git a/service/netenv/main.go b/service/netenv/main.go index 3363754a1..c94062a76 100644 --- a/service/netenv/main.go +++ b/service/netenv/main.go @@ -3,8 +3,8 @@ package netenv import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" ) // Event Names. diff --git a/service/netenv/network-change.go b/service/netenv/network-change.go index 1188db2d7..143256f90 100644 --- a/service/netenv/network-change.go +++ b/service/netenv/network-change.go @@ -7,8 +7,8 @@ import ( "io" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" ) var ( diff --git a/service/netenv/online-status.go b/service/netenv/online-status.go index fac5e170e..018d3dc16 100644 --- a/service/netenv/online-status.go +++ b/service/netenv/online-status.go @@ -13,8 +13,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/updates" ) diff --git a/service/netquery/database.go b/service/netquery/database.go index cb9f40397..a1cd6aea9 100644 --- a/service/netquery/database.go +++ b/service/netquery/database.go @@ -16,9 +16,9 @@ import ( "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" - "github.com/safing/portbase/config" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netquery/orm" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/netutils" diff --git a/service/netquery/manager.go b/service/netquery/manager.go index d1809116a..31f2f7f8d 100644 --- a/service/netquery/manager.go +++ b/service/netquery/manager.go @@ -6,10 +6,10 @@ import ( "fmt" "time" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" - "github.com/safing/portbase/runtime" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/network" ) diff --git 
a/service/netquery/module_api.go b/service/netquery/module_api.go index 00950a019..5d58709b8 100644 --- a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -9,14 +9,14 @@ import ( "github.com/hashicorp/go-multierror" servertiming "github.com/mitchellh/go-server-timing" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/modules/subsystems" - "github.com/safing/portbase/runtime" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/modules/subsystems" + "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/network" ) diff --git a/service/netquery/orm/schema_builder.go b/service/netquery/orm/schema_builder.go index 893dab2e7..90805c80e 100644 --- a/service/netquery/orm/schema_builder.go +++ b/service/netquery/orm/schema_builder.go @@ -9,7 +9,7 @@ import ( "zombiezen.com/go/sqlite" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) var errSkipStructField = errors.New("struct field should be skipped") diff --git a/service/netquery/query_handler.go b/service/netquery/query_handler.go index e996c1835..1e3afbff2 100644 --- a/service/netquery/query_handler.go +++ b/service/netquery/query_handler.go @@ -13,7 +13,7 @@ import ( "github.com/hashicorp/go-multierror" servertiming "github.com/mitchellh/go-server-timing" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netquery/orm" ) diff --git a/service/netquery/runtime_query_runner.go b/service/netquery/runtime_query_runner.go index 67ba449b8..0d09b51d7 100644 --- a/service/netquery/runtime_query_runner.go +++ b/service/netquery/runtime_query_runner.go @@ -6,10 +6,10 @@ import ( "fmt" "strings" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" - "github.com/safing/portbase/runtime" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/netquery/orm" ) diff --git a/service/network/api.go b/service/network/api.go index 878db313c..0624f806f 100644 --- a/service/network/api.go +++ b/service/network/api.go @@ -8,10 +8,10 @@ import ( "strings" "time" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/utils/debug" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/utils/debug" "github.com/safing/portmaster/service/network/state" "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/resolver" diff --git a/service/network/clean.go b/service/network/clean.go index c2777164c..3b04990b9 100644 --- a/service/network/clean.go +++ b/service/network/clean.go @@ -4,7 +4,7 @@ import ( "context" "time" - "github.com/safing/portbase/log" + 
"github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/network/state" "github.com/safing/portmaster/service/process" diff --git a/service/network/connection.go b/service/network/connection.go index b83ee5423..deff0ae94 100644 --- a/service/network/connection.go +++ b/service/network/connection.go @@ -11,9 +11,9 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" diff --git a/service/network/database.go b/service/network/database.go index 9b098d487..33fd8038d 100644 --- a/service/network/database.go +++ b/service/network/database.go @@ -6,11 +6,11 @@ import ( "strconv" "strings" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/iterator" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/database/storage" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" "github.com/safing/portmaster/service/process" ) diff --git a/service/network/dns.go b/service/network/dns.go index 201dd25b1..9b0dbe94e 100644 --- a/service/network/dns.go +++ b/service/network/dns.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" "golang.org/x/exp/slices" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/nameserver/nsutil" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/process" diff --git a/service/network/iphelper/tables.go b/service/network/iphelper/tables.go index b221d2ef3..8315ca781 100644 --- a/service/network/iphelper/tables.go +++ b/service/network/iphelper/tables.go @@ -10,7 +10,7 @@ import ( "sync" "unsafe" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/socket" "golang.org/x/sys/windows" diff --git a/service/network/metrics.go b/service/network/metrics.go index 5ffa18800..e64ed1638 100644 --- a/service/network/metrics.go +++ b/service/network/metrics.go @@ -1,9 +1,9 @@ package network import ( - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/metrics" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/metrics" "github.com/safing/portmaster/service/process" ) diff --git a/service/network/module.go b/service/network/module.go index bebcb4670..2863cf99f 100644 --- a/service/network/module.go +++ b/service/network/module.go @@ -6,8 +6,8 @@ import ( "strings" "sync" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/state" "github.com/safing/portmaster/service/profile" diff --git a/service/network/ports.go 
b/service/network/ports.go index ab870ff0f..11c322413 100644 --- a/service/network/ports.go +++ b/service/network/ports.go @@ -1,8 +1,8 @@ package network import ( - "github.com/safing/portbase/log" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/rng" ) // GetUnusedLocalPort returns a local port of the specified protocol that is diff --git a/service/network/proc/findpid.go b/service/network/proc/findpid.go index e5cd5185f..13cf0aa4a 100644 --- a/service/network/proc/findpid.go +++ b/service/network/proc/findpid.go @@ -8,7 +8,7 @@ import ( "os" "strconv" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/socket" ) diff --git a/service/network/proc/pids_by_user.go b/service/network/proc/pids_by_user.go index a55d0489e..7fd451e7c 100644 --- a/service/network/proc/pids_by_user.go +++ b/service/network/proc/pids_by_user.go @@ -11,8 +11,8 @@ import ( "syscall" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" ) var ( diff --git a/service/network/proc/tables.go b/service/network/proc/tables.go index 2569a7f0a..0fbfc79cd 100644 --- a/service/network/proc/tables.go +++ b/service/network/proc/tables.go @@ -12,7 +12,7 @@ import ( "strings" "unicode" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/socket" ) diff --git a/service/network/state/info.go b/service/network/state/info.go index 306c36a05..1d680cc4c 100644 --- a/service/network/state/info.go +++ b/service/network/state/info.go @@ -3,7 +3,7 @@ package state import ( "sync" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/socket" ) diff --git a/service/network/state/system_default.go b/service/network/state/system_default.go index 9ccf96c92..5ca01534e 100644 --- a/service/network/state/system_default.go +++ b/service/network/state/system_default.go @@ -6,7 +6,7 @@ package state import ( "time" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/network/socket" ) diff --git a/service/network/state/tcp.go b/service/network/state/tcp.go index 33e053be7..5d08b0541 100644 --- a/service/network/state/tcp.go +++ b/service/network/state/tcp.go @@ -6,8 +6,8 @@ import ( "sync/atomic" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/network/socket" ) diff --git a/service/network/state/udp.go b/service/network/state/udp.go index ce7139e47..1c534b7f2 100644 --- a/service/network/state/udp.go +++ b/service/network/state/udp.go @@ -8,8 +8,8 @@ import ( "sync/atomic" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/network/socket" diff --git a/service/process/api.go b/service/process/api.go index a2aca7f63..a17b3c70d 100644 --- a/service/process/api.go +++ b/service/process/api.go @@ -5,7 +5,7 @@ import ( "net/http" "strconv" - 
"github.com/safing/portbase/api" + "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/service/profile" ) diff --git a/service/process/config.go b/service/process/config.go index 3e91aa550..69dc3c568 100644 --- a/service/process/config.go +++ b/service/process/config.go @@ -1,7 +1,7 @@ package process import ( - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" ) // Configuration Keys. diff --git a/service/process/database.go b/service/process/database.go index 091d1470d..689a3faed 100644 --- a/service/process/database.go +++ b/service/process/database.go @@ -11,8 +11,8 @@ import ( processInfo "github.com/shirou/gopsutil/process" "github.com/tevino/abool" - "github.com/safing/portbase/database" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/profile" ) diff --git a/service/process/find.go b/service/process/find.go index 8438ecfd0..0a3ff7af7 100644 --- a/service/process/find.go +++ b/service/process/find.go @@ -6,8 +6,8 @@ import ( "net" "time" - "github.com/safing/portbase/api" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/network/reference" diff --git a/service/process/module.go b/service/process/module.go index cef4fe2a0..be97b26ea 100644 --- a/service/process/module.go +++ b/service/process/module.go @@ -3,7 +3,7 @@ package process import ( "os" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/updates" ) diff --git a/service/process/process.go b/service/process/process.go index 4508310e3..60dac7ebf 100644 --- a/service/process/process.go +++ b/service/process/process.go @@ -13,8 +13,8 @@ import ( processInfo "github.com/shirou/gopsutil/process" "golang.org/x/sync/singleflight" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/profile" ) diff --git a/service/process/process_linux.go b/service/process/process_linux.go index df3fcc943..73121e18f 100644 --- a/service/process/process_linux.go +++ b/service/process/process_linux.go @@ -5,7 +5,7 @@ import ( "fmt" "syscall" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) const ( diff --git a/service/process/profile.go b/service/process/profile.go index 53599913f..e8c766eea 100644 --- a/service/process/profile.go +++ b/service/process/profile.go @@ -7,7 +7,7 @@ import ( "runtime" "strings" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/profile" ) diff --git a/service/process/special.go b/service/process/special.go index 5733c2ba2..937537eea 100644 --- a/service/process/special.go +++ b/service/process/special.go @@ -6,7 +6,7 @@ import ( "golang.org/x/sync/singleflight" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/socket" "github.com/safing/portmaster/service/profile" ) diff --git a/service/process/tags/appimage_unix.go b/service/process/tags/appimage_unix.go index 1e1bd2598..cae7365c2 100644 --- a/service/process/tags/appimage_unix.go +++ 
b/service/process/tags/appimage_unix.go @@ -7,7 +7,7 @@ import ( "regexp" "strings" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/service/profile/binmeta" diff --git a/service/process/tags/svchost_windows.go b/service/process/tags/svchost_windows.go index 83087cbc2..236031b98 100644 --- a/service/process/tags/svchost_windows.go +++ b/service/process/tags/svchost_windows.go @@ -5,8 +5,8 @@ import ( "fmt" "strings" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils/osdetail" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils/osdetail" "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/service/profile/binmeta" diff --git a/service/process/tags/winstore_windows.go b/service/process/tags/winstore_windows.go index e41995c8c..aa760eac2 100644 --- a/service/process/tags/winstore_windows.go +++ b/service/process/tags/winstore_windows.go @@ -4,8 +4,8 @@ import ( "os" "strings" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/service/profile/binmeta" diff --git a/service/profile/api.go b/service/profile/api.go index 048348562..60484e10e 100644 --- a/service/profile/api.go +++ b/service/profile/api.go @@ -7,9 +7,9 @@ import ( "path/filepath" "strings" - "github.com/safing/portbase/api" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/profile/binmeta" ) diff --git a/service/profile/binmeta/icon.go b/service/profile/binmeta/icon.go index 64ab6e43b..3f240051b 100644 --- a/service/profile/binmeta/icon.go +++ b/service/profile/binmeta/icon.go @@ -9,8 +9,8 @@ import ( "github.com/vincent-petithory/dataurl" "golang.org/x/exp/slices" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" ) // Icon describes an icon. diff --git a/service/profile/binmeta/icons.go b/service/profile/binmeta/icons.go index 3abe6ce85..73a8e4803 100644 --- a/service/profile/binmeta/icons.go +++ b/service/profile/binmeta/icons.go @@ -11,7 +11,7 @@ import ( "path/filepath" "strings" - "github.com/safing/portbase/api" + "github.com/safing/portmaster/base/api" ) // ProfileIconStoragePath defines the location where profile icons are stored. 
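The hunks in this patch repeat one mechanical change per file: every import under github.com/safing/portbase/ becomes github.com/safing/portmaster/base/, while package names and call sites stay untouched. A rewrite of this shape can be scripted; the following is a minimal sketch (a hypothetical helper, not tooling from this patch) that walks a Go source tree and rewrites matching import paths via golang.org/x/tools/go/ast/astutil, assuming gofmt or goimports is run afterwards to regroup the import blocks the way these hunks show them.

// rewriteimports.go (hypothetical): rewrite portbase import paths to the
// new monorepo location. Run gofmt/goimports afterwards to regroup imports.
package main

import (
	"go/format"
	"go/parser"
	"go/token"
	"log"
	"os"
	"path/filepath"
	"strings"

	"golang.org/x/tools/go/ast/astutil"
)

const (
	oldPrefix = "github.com/safing/portbase/"
	newPrefix = "github.com/safing/portmaster/base/"
)

func main() {
	err := filepath.Walk(".", func(path string, info os.FileInfo, err error) error {
		if err != nil || info.IsDir() || !strings.HasSuffix(path, ".go") {
			return err
		}
		fset := token.NewFileSet()
		// Parse with comments so re-printing the file loses nothing.
		file, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
		if err != nil {
			return err
		}
		rewrote := false
		for _, imp := range file.Imports {
			p := strings.Trim(imp.Path.Value, `"`)
			if strings.HasPrefix(p, oldPrefix) {
				// Rewrite the exact old path to its new prefix in place.
				if astutil.RewriteImport(fset, file, p, newPrefix+strings.TrimPrefix(p, oldPrefix)) {
					rewrote = true
				}
			}
		}
		if !rewrote {
			return nil
		}
		out, err := os.Create(path)
		if err != nil {
			return err
		}
		defer out.Close()
		return format.Node(out, fset, file)
	})
	if err != nil {
		log.Fatal(err)
	}
}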
diff --git a/service/profile/config-update.go b/service/profile/config-update.go index 3c31603c0..eebf0b94b 100644 --- a/service/profile/config-update.go +++ b/service/profile/config-update.go @@ -6,7 +6,7 @@ import ( "sync" "time" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/intel/filterlists" "github.com/safing/portmaster/service/profile/endpoints" ) diff --git a/service/profile/config.go b/service/profile/config.go index a2b5da0a6..c3dbb26ae 100644 --- a/service/profile/config.go +++ b/service/profile/config.go @@ -3,7 +3,7 @@ package profile import ( "strings" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/service/status" "github.com/safing/portmaster/spn/access/account" diff --git a/service/profile/database.go b/service/profile/database.go index a9f927a14..ce0633cef 100644 --- a/service/profile/database.go +++ b/service/profile/database.go @@ -5,11 +5,11 @@ import ( "errors" "strings" - "github.com/safing/portbase/config" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" ) // Database paths: diff --git a/service/profile/fingerprint.go b/service/profile/fingerprint.go index 56e921724..2185ac33e 100644 --- a/service/profile/fingerprint.go +++ b/service/profile/fingerprint.go @@ -8,7 +8,7 @@ import ( "golang.org/x/exp/slices" "github.com/safing/jess/lhash" - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" ) // # Matching and Scores diff --git a/service/profile/framework.go b/service/profile/framework.go index 4eae57c7e..a7fb2aac3 100644 --- a/service/profile/framework.go +++ b/service/profile/framework.go @@ -9,7 +9,7 @@ package profile // "regexp" // "strings" // -// "github.com/safing/portbase/log" +// "github.com/safing/portmaster/base/log" // ) // // type Framework struct { diff --git a/service/profile/get.go b/service/profile/get.go index 565011800..36e3e0d9f 100644 --- a/service/profile/get.go +++ b/service/profile/get.go @@ -8,11 +8,11 @@ import ( "strings" "sync" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" ) var getProfileLock sync.Mutex diff --git a/service/profile/merge.go b/service/profile/merge.go index 5e9951827..cc3755dbf 100644 --- a/service/profile/merge.go +++ b/service/profile/merge.go @@ -6,7 +6,7 @@ import ( "sync" "time" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/service/profile/binmeta" ) diff --git a/service/profile/meta.go b/service/profile/meta.go index de7c945c7..a2ef73f97 100644 --- a/service/profile/meta.go +++ b/service/profile/meta.go @@ -5,7 +5,7 @@ import ( "sync" "time" - 
"github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database/record" ) // ProfilesMetadata holds metadata about all profiles that are not fit to be diff --git a/service/profile/migrations.go b/service/profile/migrations.go index 5eb943131..dfa6bf821 100644 --- a/service/profile/migrations.go +++ b/service/profile/migrations.go @@ -7,10 +7,10 @@ import ( "github.com/hashicorp/go-version" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/migration" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/migration" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/profile/binmeta" ) diff --git a/service/profile/module.go b/service/profile/module.go index 4465750d0..aaca99d91 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -5,11 +5,11 @@ import ( "fmt" "os" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/migration" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/migration" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/profile/binmeta" "github.com/safing/portmaster/service/updates" diff --git a/service/profile/profile-layered-provider.go b/service/profile/profile-layered-provider.go index 81d54c4bd..64ba7bca9 100644 --- a/service/profile/profile-layered-provider.go +++ b/service/profile/profile-layered-provider.go @@ -4,9 +4,9 @@ import ( "errors" "strings" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portbase/runtime" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/runtime" ) const ( diff --git a/service/profile/profile-layered.go b/service/profile/profile-layered.go index 2635aed5b..fcb5b22f2 100644 --- a/service/profile/profile-layered.go +++ b/service/profile/profile-layered.go @@ -5,10 +5,10 @@ import ( "sync" "sync/atomic" - "github.com/safing/portbase/config" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portbase/runtime" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/profile/endpoints" ) diff --git a/service/profile/profile.go b/service/profile/profile.go index a11f7dd95..97bdca8cb 100644 --- a/service/profile/profile.go +++ b/service/profile/profile.go @@ -11,10 +11,10 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/config" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/intel/filterlists" 
"github.com/safing/portmaster/service/profile/binmeta" "github.com/safing/portmaster/service/profile/endpoints" diff --git a/service/profile/special.go b/service/profile/special.go index 6e95e2789..55b466ece 100644 --- a/service/profile/special.go +++ b/service/profile/special.go @@ -3,7 +3,7 @@ package profile import ( "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) const ( diff --git a/service/resolver/api.go b/service/resolver/api.go index 727034624..c16f031ec 100644 --- a/service/resolver/api.go +++ b/service/resolver/api.go @@ -3,8 +3,8 @@ package resolver import ( "net/http" - "github.com/safing/portbase/api" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/database/record" ) func registerAPI() error { diff --git a/service/resolver/config.go b/service/resolver/config.go index b5538d7d4..061847717 100644 --- a/service/resolver/config.go +++ b/service/resolver/config.go @@ -5,7 +5,7 @@ import ( "fmt" "strings" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/status" ) diff --git a/service/resolver/failing.go b/service/resolver/failing.go index 2f1ff87b8..c1e347b5d 100644 --- a/service/resolver/failing.go +++ b/service/resolver/failing.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/resolver/ipinfo.go b/service/resolver/ipinfo.go index 6c1eed317..89cf9297c 100644 --- a/service/resolver/ipinfo.go +++ b/service/resolver/ipinfo.go @@ -6,8 +6,8 @@ import ( "sync" "time" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" ) const ( diff --git a/service/resolver/main.go b/service/resolver/main.go index 693797b55..f50dcc226 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -10,10 +10,10 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/notifications" - "github.com/safing/portbase/utils/debug" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/base/utils/debug" _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/netenv" diff --git a/service/resolver/metrics.go b/service/resolver/metrics.go index f118d0751..02ce9897b 100644 --- a/service/resolver/metrics.go +++ b/service/resolver/metrics.go @@ -5,9 +5,9 @@ import ( "sync/atomic" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/notifications" ) var ( diff --git a/service/resolver/namerecord.go b/service/resolver/namerecord.go index e3db05e06..075b53287 100644 --- a/service/resolver/namerecord.go +++ b/service/resolver/namerecord.go @@ -6,11 +6,11 @@ import ( "fmt" "sync" - "github.com/safing/portbase/api" - "github.com/safing/portbase/database" - 
"github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" ) const ( diff --git a/service/resolver/resolve.go b/service/resolver/resolve.go index fe3e11ff9..1f85e59ef 100644 --- a/service/resolver/resolve.go +++ b/service/resolver/resolve.go @@ -12,8 +12,8 @@ import ( "github.com/miekg/dns" "golang.org/x/net/publicsuffix" - "github.com/safing/portbase/database" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/resolver/resolver-env.go b/service/resolver/resolver-env.go index 01f58ea70..808005968 100644 --- a/service/resolver/resolver-env.go +++ b/service/resolver/resolver-env.go @@ -8,7 +8,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" ) diff --git a/service/resolver/resolver-https.go b/service/resolver/resolver-https.go index ed04bf92d..46bc2ecfd 100644 --- a/service/resolver/resolver-https.go +++ b/service/resolver/resolver-https.go @@ -13,7 +13,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/resolver/resolver-mdns.go b/service/resolver/resolver-mdns.go index 17f034c85..cb4ba62c4 100644 --- a/service/resolver/resolver-mdns.go +++ b/service/resolver/resolver-mdns.go @@ -11,7 +11,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" ) diff --git a/service/resolver/resolver-plain.go b/service/resolver/resolver-plain.go index 56f85458c..66100c63d 100644 --- a/service/resolver/resolver-plain.go +++ b/service/resolver/resolver-plain.go @@ -8,7 +8,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/resolver/resolver-tcp.go b/service/resolver/resolver-tcp.go index 271d8808b..d94a40e3b 100644 --- a/service/resolver/resolver-tcp.go +++ b/service/resolver/resolver-tcp.go @@ -12,7 +12,7 @@ import ( "github.com/miekg/dns" "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/resolver/resolver.go b/service/resolver/resolver.go index 3474fd305..0046ab8ea 100644 --- a/service/resolver/resolver.go +++ b/service/resolver/resolver.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" "github.com/tevino/abool" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" ) diff --git a/service/resolver/resolver_test.go b/service/resolver/resolver_test.go index 397a914cd..5292155b1 100644 --- a/service/resolver/resolver_test.go +++ b/service/resolver/resolver_test.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" 
"github.com/stretchr/testify/assert" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) var ( diff --git a/service/resolver/resolvers.go b/service/resolver/resolvers.go index 93edf2a16..055aa6311 100644 --- a/service/resolver/resolvers.go +++ b/service/resolver/resolvers.go @@ -13,8 +13,8 @@ import ( "github.com/miekg/dns" "golang.org/x/net/publicsuffix" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" ) diff --git a/service/resolver/reverse.go b/service/resolver/reverse.go index f8abb9623..17bccc76e 100644 --- a/service/resolver/reverse.go +++ b/service/resolver/reverse.go @@ -7,7 +7,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) // ResolveIPAndValidate finds (reverse DNS), validates (forward DNS) and returns the domain name assigned to the given IP. diff --git a/service/resolver/reverse_test.go b/service/resolver/reverse_test.go index 421df6725..5b190f6e5 100644 --- a/service/resolver/reverse_test.go +++ b/service/resolver/reverse_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) func testReverse(t *testing.T, ip, result, expectedErr string) { diff --git a/service/resolver/rrcache.go b/service/resolver/rrcache.go index 36b46e318..3a64dfa6c 100644 --- a/service/resolver/rrcache.go +++ b/service/resolver/rrcache.go @@ -8,7 +8,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/nameserver/nsutil" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/resolver/scopes.go b/service/resolver/scopes.go index ac1391b1e..44c07d8a4 100644 --- a/service/resolver/scopes.go +++ b/service/resolver/scopes.go @@ -7,7 +7,7 @@ import ( "github.com/miekg/dns" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/status/module.go b/service/status/module.go index d10d51dc1..bb6d29fb3 100644 --- a/service/status/module.go +++ b/service/status/module.go @@ -4,8 +4,8 @@ import ( "context" "fmt" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/utils/debug" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/utils/debug" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/status/provider.go b/service/status/provider.go index 5130560eb..a8707e3b4 100644 --- a/service/status/provider.go +++ b/service/status/provider.go @@ -1,8 +1,8 @@ package status import ( - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/runtime" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/status/records.go b/service/status/records.go index 56f19e5f6..a094bdb5c 100644 --- a/service/status/records.go +++ b/service/status/records.go @@ -3,7 +3,7 @@ package status import ( "sync" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/service/netenv" ) diff --git a/service/status/security_level.go b/service/status/security_level.go index 46641fc2f..79d27d265 100644 --- 
a/service/status/security_level.go +++ b/service/status/security_level.go @@ -1,6 +1,6 @@ package status -import "github.com/safing/portbase/config" +import "github.com/safing/portmaster/base/config" // MigrateSecurityLevelToBoolean migrates a security level (int) option value to a boolean option value. func MigrateSecurityLevelToBoolean(option *config.Option, value any) any { diff --git a/service/sync/module.go b/service/sync/module.go index 0c6ebc633..e6d43142f 100644 --- a/service/sync/module.go +++ b/service/sync/module.go @@ -1,8 +1,8 @@ package sync import ( - "github.com/safing/portbase/database" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/modules" ) var ( diff --git a/service/sync/profile.go b/service/sync/profile.go index 22a6472bb..cb75d37da 100644 --- a/service/sync/profile.go +++ b/service/sync/profile.go @@ -10,9 +10,9 @@ import ( "github.com/vincent-petithory/dataurl" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/service/profile/binmeta" ) diff --git a/service/sync/setting_single.go b/service/sync/setting_single.go index 8911d6e43..9fd9c9a6d 100644 --- a/service/sync/setting_single.go +++ b/service/sync/setting_single.go @@ -7,9 +7,9 @@ import ( "net/http" "time" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/service/profile" ) diff --git a/service/sync/settings.go b/service/sync/settings.go index 4e640d099..4d39b2874 100644 --- a/service/sync/settings.go +++ b/service/sync/settings.go @@ -8,8 +8,8 @@ import ( "strings" "time" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/profile" ) diff --git a/service/sync/util.go b/service/sync/util.go index 3d95b938b..bbd09e350 100644 --- a/service/sync/util.go +++ b/service/sync/util.go @@ -9,9 +9,9 @@ import ( yaml "gopkg.in/yaml.v3" "github.com/safing/jess/filesig" - "github.com/safing/portbase/api" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" ) // Type is the type of an export. 
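Because only the module path moves, none of the files above need changes beyond their import blocks; a consumer of the relocated packages looks the same before and after. A minimal illustration (not a file from this patch; log.Infof is assumed to keep its portbase signature):

// example.go (illustrative only): after the move, consumers change nothing
// but the import path; the package name and API stay the same.
package example

import (
	"github.com/safing/portmaster/base/log" // previously github.com/safing/portbase/log
)

func reportStartup(component string) {
	log.Infof("%s: started", component)
}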
diff --git a/service/ui/api.go b/service/ui/api.go index 9eb5f6e1c..5e57dfe5c 100644 --- a/service/ui/api.go +++ b/service/ui/api.go @@ -1,8 +1,8 @@ package ui import ( - "github.com/safing/portbase/api" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/log" ) func registerAPIEndpoints() error { diff --git a/service/ui/serve.go b/service/ui/serve.go index 1e9e58610..d8c9f5f26 100644 --- a/service/ui/serve.go +++ b/service/ui/serve.go @@ -13,11 +13,10 @@ import ( "github.com/spkg/zipfs" - "github.com/safing/portbase/api" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/updater" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/updates" ) diff --git a/service/updates/api.go b/service/updates/api.go index 57b75a3c6..5917ec4bb 100644 --- a/service/updates/api.go +++ b/service/updates/api.go @@ -10,9 +10,9 @@ import ( "github.com/ghodss/yaml" - "github.com/safing/portbase/api" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" ) const ( diff --git a/service/updates/config.go b/service/updates/config.go index c06e7793f..43f815ae9 100644 --- a/service/updates/config.go +++ b/service/updates/config.go @@ -5,8 +5,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/service/updates/export.go b/service/updates/export.go index 0f355d580..98ae4ba99 100644 --- a/service/updates/export.go +++ b/service/updates/export.go @@ -7,11 +7,11 @@ import ( "strings" "sync" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/info" - "github.com/safing/portbase/log" - "github.com/safing/portbase/updater" - "github.com/safing/portbase/utils/debug" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/base/utils/debug" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/service/updates/get.go b/service/updates/get.go index c133ae1fb..4a35535fd 100644 --- a/service/updates/get.go +++ b/service/updates/get.go @@ -3,7 +3,7 @@ package updates import ( "path" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/service/updates/helper/electron.go b/service/updates/helper/electron.go index b6ebede55..4c8c4a07d 100644 --- a/service/updates/helper/electron.go +++ b/service/updates/helper/electron.go @@ -8,8 +8,8 @@ import ( "runtime" "strings" - "github.com/safing/portbase/log" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/updater" ) var pmElectronUpdate *updater.File diff --git a/service/updates/helper/indexes.go b/service/updates/helper/indexes.go index 8a272ea55..72457bc53 100644 --- a/service/updates/helper/indexes.go +++ b/service/updates/helper/indexes.go @@ -8,7 +8,7 @@ import ( 
"path/filepath" "github.com/safing/jess/filesig" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/updater" ) // Release Channel Configuration Keys. diff --git a/service/updates/helper/signing.go b/service/updates/helper/signing.go index 78ccab470..136b1970b 100644 --- a/service/updates/helper/signing.go +++ b/service/updates/helper/signing.go @@ -2,7 +2,7 @@ package helper import ( "github.com/safing/jess" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/updater" ) var ( diff --git a/service/updates/main.go b/service/updates/main.go index 218675b8a..a8d500388 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -9,11 +9,11 @@ import ( "runtime" "time" - "github.com/safing/portbase/database" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/service/updates/notify.go b/service/updates/notify.go index d1e1622f2..01dde5234 100644 --- a/service/updates/notify.go +++ b/service/updates/notify.go @@ -6,7 +6,7 @@ import ( "sync/atomic" "time" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/notifications" ) const ( diff --git a/service/updates/os_integration_linux.go b/service/updates/os_integration_linux.go index b20e3da46..cef0b9ef1 100644 --- a/service/updates/os_integration_linux.go +++ b/service/updates/os_integration_linux.go @@ -15,9 +15,9 @@ import ( "github.com/tevino/abool" "golang.org/x/exp/slices" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils/renameio" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils/renameio" ) var ( diff --git a/service/updates/restart.go b/service/updates/restart.go index a3f911f85..b08fdf30c 100644 --- a/service/updates/restart.go +++ b/service/updates/restart.go @@ -9,8 +9,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" ) const ( diff --git a/service/updates/state.go b/service/updates/state.go index 2e0510574..3a1144b12 100644 --- a/service/updates/state.go +++ b/service/updates/state.go @@ -1,9 +1,9 @@ package updates import ( - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/runtime" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/runtime" + "github.com/safing/portmaster/base/updater" ) var pushRegistryStatusUpdate runtime.PushFunc diff --git a/service/updates/upgrader.go b/service/updates/upgrader.go index 03ec64db1..1dc64f16d 100644 --- a/service/updates/upgrader.go +++ b/service/updates/upgrader.go @@ -14,13 +14,13 @@ import ( processInfo "github.com/shirou/gopsutil/process" "github.com/tevino/abool" - "github.com/safing/portbase/dataroot" - "github.com/safing/portbase/info" - "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" - "github.com/safing/portbase/rng" - "github.com/safing/portbase/updater" - 
"github.com/safing/portbase/utils/renameio" + "github.com/safing/portmaster/base/dataroot" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/base/utils/renameio" "github.com/safing/portmaster/service/updates/helper" ) diff --git a/spn/access/api.go b/spn/access/api.go index c97370bc9..a45ac0a72 100644 --- a/spn/access/api.go +++ b/spn/access/api.go @@ -5,9 +5,9 @@ import ( "fmt" "net/http" - "github.com/safing/portbase/api" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/access/account" ) diff --git a/spn/access/client.go b/spn/access/client.go index f22bb9e97..0e08c8880 100644 --- a/spn/access/client.go +++ b/spn/access/client.go @@ -8,9 +8,9 @@ import ( "sync" "time" - "github.com/safing/portbase/database" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/access/account" "github.com/safing/portmaster/spn/access/token" ) diff --git a/spn/access/database.go b/spn/access/database.go index be5ea95a8..6dd99b9db 100644 --- a/spn/access/database.go +++ b/spn/access/database.go @@ -7,8 +7,8 @@ import ( "sync" "time" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/spn/access/account" ) diff --git a/spn/access/module.go b/spn/access/module.go index 3f935f33e..5779d52ed 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -8,9 +8,9 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/access/account" "github.com/safing/portmaster/spn/access/token" "github.com/safing/portmaster/spn/conf" diff --git a/spn/access/notify.go b/spn/access/notify.go index 978a2f16d..3ed9f6471 100644 --- a/spn/access/notify.go +++ b/spn/access/notify.go @@ -5,8 +5,8 @@ import ( "strings" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" ) const ( diff --git a/spn/access/op_auth.go b/spn/access/op_auth.go index 764c73c32..9ce642743 100644 --- a/spn/access/op_auth.go +++ b/spn/access/op_auth.go @@ -3,8 +3,8 @@ package access import ( "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/access/token" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/access/storage.go b/spn/access/storage.go index fcbb7edc8..6f6a924f3 100644 --- a/spn/access/storage.go +++ b/spn/access/storage.go @@ -6,11 +6,11 @@ import ( "fmt" "time" - "github.com/safing/portbase/database" - 
"github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/access/token" ) diff --git a/spn/access/token/module_test.go b/spn/access/token/module_test.go index bb79d76f0..b3cc49b8a 100644 --- a/spn/access/token/module_test.go +++ b/spn/access/token/module_test.go @@ -3,7 +3,7 @@ package token import ( "testing" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/core/pmtesting" ) diff --git a/spn/access/token/pblind.go b/spn/access/token/pblind.go index 71f137a3f..10e9bbfd2 100644 --- a/spn/access/token/pblind.go +++ b/spn/access/token/pblind.go @@ -13,8 +13,8 @@ import ( "github.com/mr-tron/base58" "github.com/rot256/pblind" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" ) const pblindSecretSize = 32 diff --git a/spn/access/token/request_test.go b/spn/access/token/request_test.go index 7040672a8..e5525a412 100644 --- a/spn/access/token/request_test.go +++ b/spn/access/token/request_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/formats/dsd" ) func TestFull(t *testing.T) { diff --git a/spn/access/token/scramble.go b/spn/access/token/scramble.go index df96bcc6f..c6ef236f3 100644 --- a/spn/access/token/scramble.go +++ b/spn/access/token/scramble.go @@ -7,7 +7,7 @@ import ( "github.com/mr-tron/base58" "github.com/safing/jess/lhash" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/formats/dsd" ) const ( diff --git a/spn/access/token/token.go b/spn/access/token/token.go index b93ed194a..9b615b1c4 100644 --- a/spn/access/token/token.go +++ b/spn/access/token/token.go @@ -8,7 +8,7 @@ import ( "github.com/mr-tron/base58" - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" ) // Token represents a token, consisting of a zone (name) and some data. 
diff --git a/spn/access/token/token_test.go b/spn/access/token/token_test.go index b132265a0..6a954556c 100644 --- a/spn/access/token/token_test.go +++ b/spn/access/token/token_test.go @@ -3,7 +3,7 @@ package token import ( "testing" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/rng" ) func TestToken(t *testing.T) { diff --git a/spn/access/zones.go b/spn/access/zones.go index 1f9c954bc..444ebf2da 100644 --- a/spn/access/zones.go +++ b/spn/access/zones.go @@ -9,7 +9,7 @@ import ( "github.com/tevino/abool" "github.com/safing/jess/lhash" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/access/token" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" diff --git a/spn/cabin/config-public.go b/spn/cabin/config-public.go index 4ae733ae6..057765a78 100644 --- a/spn/cabin/config-public.go +++ b/spn/cabin/config-public.go @@ -5,8 +5,8 @@ import ( "net" "os" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/spn/hub" diff --git a/spn/cabin/database.go b/spn/cabin/database.go index 410975302..b84418f97 100644 --- a/spn/cabin/database.go +++ b/spn/cabin/database.go @@ -4,8 +4,8 @@ import ( "errors" "fmt" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/cabin/identity.go b/spn/cabin/identity.go index 0be583cf9..d775c0162 100644 --- a/spn/cabin/identity.go +++ b/spn/cabin/identity.go @@ -8,9 +8,9 @@ import ( "github.com/safing/jess" "github.com/safing/jess/tools" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/info" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/cabin/keys.go b/spn/cabin/keys.go index 67d203a4e..665c6d656 100644 --- a/spn/cabin/keys.go +++ b/spn/cabin/keys.go @@ -8,8 +8,8 @@ import ( "github.com/safing/jess" "github.com/safing/jess/tools" - "github.com/safing/portbase/log" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/cabin/module.go b/spn/cabin/module.go index 8644502f5..3a1dd78e2 100644 --- a/spn/cabin/module.go +++ b/spn/cabin/module.go @@ -1,7 +1,7 @@ package cabin import ( - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/conf" ) diff --git a/spn/cabin/verification.go b/spn/cabin/verification.go index 07a993ea7..4b1c0e48e 100644 --- a/spn/cabin/verification.go +++ b/spn/cabin/verification.go @@ -6,8 +6,8 @@ import ( "fmt" "github.com/safing/jess" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/captain/api.go b/spn/captain/api.go index dcc412d82..ec43bc44e 100644 --- a/spn/captain/api.go +++ b/spn/captain/api.go 
@@ -4,10 +4,10 @@ import ( "errors" "fmt" - "github.com/safing/portbase/api" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/modules" ) const ( diff --git a/spn/captain/bootstrap.go b/spn/captain/bootstrap.go index c70961161..6516b937a 100644 --- a/spn/captain/bootstrap.go +++ b/spn/captain/bootstrap.go @@ -7,8 +7,8 @@ import ( "io/fs" "os" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/navigator" diff --git a/spn/captain/client.go b/spn/captain/client.go index b30e4e98e..b1a81d413 100644 --- a/spn/captain/client.go +++ b/spn/captain/client.go @@ -8,8 +8,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" - "github.com/safing/portbase/notifications" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/spn/access" diff --git a/spn/captain/config.go b/spn/captain/config.go index 09e6f4906..d6fa308d6 100644 --- a/spn/captain/config.go +++ b/spn/captain/config.go @@ -3,7 +3,7 @@ package captain import ( "sync" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/spn/conf" diff --git a/spn/captain/establish.go b/spn/captain/establish.go index 479098a50..ce322bd53 100644 --- a/spn/captain/establish.go +++ b/spn/captain/establish.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" diff --git a/spn/captain/intel.go b/spn/captain/intel.go index fe743c1bb..ff53bb4f2 100644 --- a/spn/captain/intel.go +++ b/spn/captain/intel.go @@ -6,8 +6,8 @@ import ( "os" "sync" - "github.com/safing/portbase/config" - "github.com/safing/portbase/updater" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/hub" diff --git a/spn/captain/module.go b/spn/captain/module.go index 356eb199d..9538f5b7a 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -8,12 +8,12 @@ import ( "net/http" "time" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/modules/subsystems" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/modules/subsystems" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/crew" diff --git 
a/spn/captain/navigation.go b/spn/captain/navigation.go index f080e0bc5..77e77c078 100644 --- a/spn/captain/navigation.go +++ b/spn/captain/navigation.go @@ -6,8 +6,8 @@ import ( "fmt" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/profile/endpoints" diff --git a/spn/captain/op_gossip.go b/spn/captain/op_gossip.go index e5fb43776..218e70e1c 100644 --- a/spn/captain/op_gossip.go +++ b/spn/captain/op_gossip.go @@ -3,9 +3,9 @@ package captain import ( "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/varint" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" diff --git a/spn/captain/op_gossip_query.go b/spn/captain/op_gossip_query.go index aaadbc214..fbda083fb 100644 --- a/spn/captain/op_gossip_query.go +++ b/spn/captain/op_gossip_query.go @@ -5,9 +5,9 @@ import ( "strings" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/varint" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" diff --git a/spn/captain/op_publish.go b/spn/captain/op_publish.go index 178d1e885..3a377df8b 100644 --- a/spn/captain/op_publish.go +++ b/spn/captain/op_publish.go @@ -3,7 +3,7 @@ package captain import ( "time" - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" diff --git a/spn/captain/piers.go b/spn/captain/piers.go index b0c994bf7..c631e201b 100644 --- a/spn/captain/piers.go +++ b/spn/captain/piers.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/profile/endpoints" diff --git a/spn/captain/public.go b/spn/captain/public.go index 04710d9f7..441182d40 100644 --- a/spn/captain/public.go +++ b/spn/captain/public.go @@ -7,10 +7,10 @@ import ( "sort" "time" - "github.com/safing/portbase/database" - "github.com/safing/portbase/log" - "github.com/safing/portbase/metrics" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" diff --git a/spn/captain/status.go b/spn/captain/status.go index 99b6632c5..661cfcbe4 100644 --- a/spn/captain/status.go +++ b/spn/captain/status.go @@ -6,10 +6,10 @@ import ( "sync" "time" - "github.com/safing/portbase/config" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/runtime" - "github.com/safing/portbase/utils/debug" 
+ "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/runtime" + "github.com/safing/portmaster/base/utils/debug" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/navigator" diff --git a/spn/crew/connect.go b/spn/crew/connect.go index 4f376e448..ca11080c5 100644 --- a/spn/crew/connect.go +++ b/spn/crew/connect.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/spn/access" diff --git a/spn/crew/metrics.go b/spn/crew/metrics.go index b9549d1e1..b11eb794d 100644 --- a/spn/crew/metrics.go +++ b/spn/crew/metrics.go @@ -5,8 +5,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/api" - "github.com/safing/portbase/metrics" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/metrics" ) var ( diff --git a/spn/crew/module.go b/spn/crew/module.go index 10d4ebed9..5bbf89294 100644 --- a/spn/crew/module.go +++ b/spn/crew/module.go @@ -3,7 +3,7 @@ package crew import ( "time" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/crew/op_connect.go b/spn/crew/op_connect.go index 0fc2174c8..228047b79 100644 --- a/spn/crew/op_connect.go +++ b/spn/crew/op_connect.go @@ -10,9 +10,9 @@ import ( "sync/atomic" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/spn/conf" diff --git a/spn/crew/op_ping.go b/spn/crew/op_ping.go index 84ee4f6ea..2976fd611 100644 --- a/spn/crew/op_ping.go +++ b/spn/crew/op_ping.go @@ -4,9 +4,9 @@ import ( "crypto/subtle" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/crew/sticky.go b/spn/crew/sticky.go index 598476fa2..4fcc39300 100644 --- a/spn/crew/sticky.go +++ b/spn/crew/sticky.go @@ -6,8 +6,8 @@ import ( "sync" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/spn/navigator" diff --git a/spn/docks/bandwidth_test.go b/spn/docks/bandwidth_test.go index 60101f1c5..c526fa9e9 100644 --- a/spn/docks/bandwidth_test.go +++ b/spn/docks/bandwidth_test.go @@ -6,8 +6,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/controller.go b/spn/docks/controller.go index 05e18e39c..4e8521757 100644 --- 
a/spn/docks/controller.go +++ b/spn/docks/controller.go @@ -1,7 +1,7 @@ package docks import ( - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/crane.go b/spn/docks/crane.go index 34dab6d35..c7e65b41b 100644 --- a/spn/docks/crane.go +++ b/spn/docks/crane.go @@ -12,10 +12,10 @@ import ( "github.com/tevino/abool" "github.com/safing/jess" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/varint" - "github.com/safing/portbase/log" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/ships" diff --git a/spn/docks/crane_establish.go b/spn/docks/crane_establish.go index 3fa26732d..71637e456 100644 --- a/spn/docks/crane_establish.go +++ b/spn/docks/crane_establish.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/crane_init.go b/spn/docks/crane_init.go index 740f7cdb2..472f9643d 100644 --- a/spn/docks/crane_init.go +++ b/spn/docks/crane_init.go @@ -5,11 +5,11 @@ import ( "time" "github.com/safing/jess" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/formats/varint" - "github.com/safing/portbase/info" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/crane_terminal.go b/spn/docks/crane_terminal.go index 731bf9531..7ac506092 100644 --- a/spn/docks/crane_terminal.go +++ b/spn/docks/crane_terminal.go @@ -3,7 +3,7 @@ package docks import ( "net" - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/crane_verify.go b/spn/docks/crane_verify.go index 1f4e686d7..cb2c86620 100644 --- a/spn/docks/crane_verify.go +++ b/spn/docks/crane_verify.go @@ -6,8 +6,8 @@ import ( "fmt" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/varint" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/cranehooks.go b/spn/docks/cranehooks.go index 0355a4f7b..e85956299 100644 --- a/spn/docks/cranehooks.go +++ b/spn/docks/cranehooks.go @@ -3,7 +3,7 @@ package docks import ( "sync" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) var ( diff --git a/spn/docks/hub_import.go b/spn/docks/hub_import.go index 377164f28..c8f46d307 100644 --- a/spn/docks/hub_import.go +++ b/spn/docks/hub_import.go @@ -6,7 +6,7 @@ import ( "net" "sync" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" 
"github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/ships" diff --git a/spn/docks/metrics.go b/spn/docks/metrics.go index 49df92bda..38fd3f1ee 100644 --- a/spn/docks/metrics.go +++ b/spn/docks/metrics.go @@ -7,8 +7,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/api" - "github.com/safing/portbase/metrics" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/metrics" ) var ( diff --git a/spn/docks/module.go b/spn/docks/module.go index 31a4da95c..f79b27cf3 100644 --- a/spn/docks/module.go +++ b/spn/docks/module.go @@ -6,8 +6,8 @@ import ( "fmt" "sync" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/rng" _ "github.com/safing/portmaster/spn/access" ) diff --git a/spn/docks/op_capacity.go b/spn/docks/op_capacity.go index a66ae6171..a4ca5b5bb 100644 --- a/spn/docks/op_capacity.go +++ b/spn/docks/op_capacity.go @@ -8,9 +8,9 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/op_expand.go b/spn/docks/op_expand.go index 4a96c7662..567d6fc47 100644 --- a/spn/docks/op_expand.go +++ b/spn/docks/op_expand.go @@ -8,7 +8,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/op_latency.go b/spn/docks/op_latency.go index 02c38f786..12e9b75eb 100644 --- a/spn/docks/op_latency.go +++ b/spn/docks/op_latency.go @@ -6,10 +6,10 @@ import ( "fmt" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/varint" - "github.com/safing/portbase/log" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/op_sync_state.go b/spn/docks/op_sync_state.go index 43530803f..e6f964611 100644 --- a/spn/docks/op_sync_state.go +++ b/spn/docks/op_sync_state.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/op_whoami.go b/spn/docks/op_whoami.go index baf5204ca..9f6ce8606 100644 --- a/spn/docks/op_whoami.go +++ b/spn/docks/op_whoami.go @@ -3,8 +3,8 @@ package docks import ( "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/docks/terminal_expansion.go b/spn/docks/terminal_expansion.go index 16895a83d..442e9bf5d 100644 --- a/spn/docks/terminal_expansion.go +++ b/spn/docks/terminal_expansion.go @@ -7,7 +7,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" 
"github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/hub/database.go b/spn/hub/database.go index d4ca3f854..c4b114709 100644 --- a/spn/hub/database.go +++ b/spn/hub/database.go @@ -6,10 +6,10 @@ import ( "sync" "time" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/iterator" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" ) var ( diff --git a/spn/hub/hub.go b/spn/hub/hub.go index efc34cd0f..0caac39b7 100644 --- a/spn/hub/hub.go +++ b/spn/hub/hub.go @@ -9,8 +9,8 @@ import ( "golang.org/x/exp/slices" "github.com/safing/jess" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/profile/endpoints" ) diff --git a/spn/hub/hub_test.go b/spn/hub/hub_test.go index 70cc5b160..8bd14ce90 100644 --- a/spn/hub/hub_test.go +++ b/spn/hub/hub_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/core/pmtesting" ) diff --git a/spn/hub/update.go b/spn/hub/update.go index 4e280b1af..98ca680c3 100644 --- a/spn/hub/update.go +++ b/spn/hub/update.go @@ -7,10 +7,10 @@ import ( "github.com/safing/jess" "github.com/safing/jess/lhash" - "github.com/safing/portbase/container" - "github.com/safing/portbase/database" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/netutils" ) diff --git a/spn/hub/update_test.go b/spn/hub/update_test.go index 982f32065..0f8d667e3 100644 --- a/spn/hub/update_test.go +++ b/spn/hub/update_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/safing/jess" - "github.com/safing/portbase/formats/dsd" + "github.com/safing/portmaster/base/formats/dsd" ) func TestHubUpdate(t *testing.T) { diff --git a/spn/navigator/api.go b/spn/navigator/api.go index 832d1126b..2da6bfa91 100644 --- a/spn/navigator/api.go +++ b/spn/navigator/api.go @@ -15,8 +15,8 @@ import ( "github.com/awalterschulze/gographviz" - "github.com/safing/portbase/api" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/navigator/api_route.go b/spn/navigator/api_route.go index 4d8548411..bae4a27fc 100644 --- a/spn/navigator/api_route.go +++ b/spn/navigator/api_route.go @@ -11,8 +11,8 @@ import ( "text/tabwriter" "time" - "github.com/safing/portbase/api" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/netenv" diff --git a/spn/navigator/database.go b/spn/navigator/database.go index b7ee8ae4d..7a4a88b3e 100644 --- 
a/spn/navigator/database.go +++ b/spn/navigator/database.go @@ -5,11 +5,11 @@ import ( "fmt" "strings" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/iterator" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/database/storage" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/iterator" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/database/storage" ) var mapDBController *database.Controller diff --git a/spn/navigator/intel.go b/spn/navigator/intel.go index d26733c1e..cd4501b40 100644 --- a/spn/navigator/intel.go +++ b/spn/navigator/intel.go @@ -6,7 +6,7 @@ import ( "golang.org/x/exp/slices" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/spn/hub" diff --git a/spn/navigator/map.go b/spn/navigator/map.go index 006dfc138..6ec8bab78 100644 --- a/spn/navigator/map.go +++ b/spn/navigator/map.go @@ -5,8 +5,8 @@ import ( "sync" "time" - "github.com/safing/portbase/database" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" diff --git a/spn/navigator/measurements.go b/spn/navigator/measurements.go index 571365cbf..f2784d1e0 100644 --- a/spn/navigator/measurements.go +++ b/spn/navigator/measurements.go @@ -5,8 +5,8 @@ import ( "sort" "time" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/terminal" ) diff --git a/spn/navigator/metrics.go b/spn/navigator/metrics.go index fe62020ea..e1f2d7425 100644 --- a/spn/navigator/metrics.go +++ b/spn/navigator/metrics.go @@ -7,8 +7,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/api" - "github.com/safing/portbase/metrics" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/metrics" ) var metricsRegistered = abool.New() diff --git a/spn/navigator/module.go b/spn/navigator/module.go index 9937ad619..3e8c6fc10 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -4,9 +4,9 @@ import ( "errors" "time" - "github.com/safing/portbase/config" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/spn/conf" ) diff --git a/spn/navigator/module_test.go b/spn/navigator/module_test.go index f55ea4e81..4433835fb 100644 --- a/spn/navigator/module_test.go +++ b/spn/navigator/module_test.go @@ -3,7 +3,7 @@ package navigator import ( "testing" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/core/pmtesting" ) diff --git a/spn/navigator/options.go b/spn/navigator/options.go index 05c93ea1d..3e2407fac 100644 --- a/spn/navigator/options.go +++ b/spn/navigator/options.go @@ -3,7 +3,7 @@ package navigator import ( "context" - 
"github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/service/profile/endpoints" diff --git a/spn/navigator/pin.go b/spn/navigator/pin.go index 9e113ab46..3c2b99bf8 100644 --- a/spn/navigator/pin.go +++ b/spn/navigator/pin.go @@ -8,7 +8,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/spn/docks" diff --git a/spn/navigator/pin_export.go b/spn/navigator/pin_export.go index 85fd279ef..422a074f9 100644 --- a/spn/navigator/pin_export.go +++ b/spn/navigator/pin_export.go @@ -4,7 +4,7 @@ import ( "sync" "time" - "github.com/safing/portbase/database/record" + "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/navigator/region.go b/spn/navigator/region.go index a3798efe0..d3dcc3b98 100644 --- a/spn/navigator/region.go +++ b/spn/navigator/region.go @@ -4,7 +4,7 @@ import ( "context" "math" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/navigator/routing-profiles.go b/spn/navigator/routing-profiles.go index 9241c0726..55508eb86 100644 --- a/spn/navigator/routing-profiles.go +++ b/spn/navigator/routing-profiles.go @@ -1,7 +1,7 @@ package navigator import ( - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/profile" ) diff --git a/spn/navigator/update.go b/spn/navigator/update.go index 3fe170740..ee2ed1956 100644 --- a/spn/navigator/update.go +++ b/spn/navigator/update.go @@ -10,13 +10,13 @@ import ( "github.com/tevino/abool" "golang.org/x/exp/slices" - "github.com/safing/portbase/config" - "github.com/safing/portbase/database" - "github.com/safing/portbase/database/query" - "github.com/safing/portbase/database/record" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database" + "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/profile" diff --git a/spn/patrol/http.go b/spn/patrol/http.go index 391518c12..e1396fde6 100644 --- a/spn/patrol/http.go +++ b/spn/patrol/http.go @@ -9,8 +9,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/conf" ) diff --git a/spn/patrol/module.go b/spn/patrol/module.go index 842c139cc..9a66cee5d 100644 --- a/spn/patrol/module.go +++ b/spn/patrol/module.go @@ -3,7 +3,7 @@ package patrol import ( "time" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/conf" ) diff --git a/spn/ships/http.go b/spn/ships/http.go index 
165ca9df4..1e92dfdb1 100644 --- a/spn/ships/http.go +++ b/spn/ships/http.go @@ -9,7 +9,7 @@ import ( "net/http" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/ships/http_info.go b/spn/ships/http_info.go index 886f2127c..c4062a5c9 100644 --- a/spn/ships/http_info.go +++ b/spn/ships/http_info.go @@ -6,9 +6,9 @@ import ( "html/template" "net/http" - "github.com/safing/portbase/config" - "github.com/safing/portbase/info" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" ) var ( diff --git a/spn/ships/http_info_test.go b/spn/ships/http_info_test.go index a490dfcec..c82ea37cc 100644 --- a/spn/ships/http_info_test.go +++ b/spn/ships/http_info_test.go @@ -4,7 +4,7 @@ import ( "html/template" "testing" - "github.com/safing/portbase/config" + "github.com/safing/portmaster/base/config" ) func TestInfoPageTemplate(t *testing.T) { diff --git a/spn/ships/http_shared.go b/spn/ships/http_shared.go index c90504e1a..3ebc49485 100644 --- a/spn/ships/http_shared.go +++ b/spn/ships/http_shared.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" ) diff --git a/spn/ships/launch.go b/spn/ships/launch.go index 45a778340..cf1fc6097 100644 --- a/spn/ships/launch.go +++ b/spn/ships/launch.go @@ -5,7 +5,7 @@ import ( "fmt" "net" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/ships/module.go b/spn/ships/module.go index d450185eb..01543ac27 100644 --- a/spn/ships/module.go +++ b/spn/ships/module.go @@ -1,7 +1,7 @@ package ships import ( - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/conf" ) diff --git a/spn/ships/ship.go b/spn/ships/ship.go index 4bb39b0e3..03918609c 100644 --- a/spn/ships/ship.go +++ b/spn/ships/ship.go @@ -7,7 +7,7 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/ships/tcp.go b/spn/ships/tcp.go index 5ffd5b905..ffc6c6979 100644 --- a/spn/ships/tcp.go +++ b/spn/ships/tcp.go @@ -6,7 +6,7 @@ import ( "net" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/sluice/module.go b/spn/sluice/module.go index 63f1d2e0f..6ca15af1b 100644 --- a/spn/sluice/module.go +++ b/spn/sluice/module.go @@ -1,8 +1,8 @@ package sluice import ( - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/spn/conf" ) diff --git a/spn/sluice/sluice.go b/spn/sluice/sluice.go index bc136b10c..6a3249f90 100644 --- a/spn/sluice/sluice.go +++ b/spn/sluice/sluice.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/netenv" ) diff --git a/spn/terminal/control_flow.go b/spn/terminal/control_flow.go index e4d15ccfc..b8572dda0 100644 --- 
a/spn/terminal/control_flow.go +++ b/spn/terminal/control_flow.go @@ -7,8 +7,8 @@ import ( "sync/atomic" "time" - "github.com/safing/portbase/formats/varint" - "github.com/safing/portbase/modules" + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/modules" ) // FlowControl defines the flow control interface. diff --git a/spn/terminal/errors.go b/spn/terminal/errors.go index 619bf1814..bc762bb3b 100644 --- a/spn/terminal/errors.go +++ b/spn/terminal/errors.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - "github.com/safing/portbase/formats/varint" + "github.com/safing/portmaster/base/formats/varint" ) // Error is a terminal error. diff --git a/spn/terminal/init.go b/spn/terminal/init.go index b99604249..3c6ce921a 100644 --- a/spn/terminal/init.go +++ b/spn/terminal/init.go @@ -4,9 +4,9 @@ import ( "context" "github.com/safing/jess" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/formats/varint" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/terminal/metrics.go b/spn/terminal/metrics.go index 0da0c326c..099bde6c1 100644 --- a/spn/terminal/metrics.go +++ b/spn/terminal/metrics.go @@ -5,8 +5,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/api" - "github.com/safing/portbase/metrics" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/metrics" ) var metricsRegistered = abool.New() diff --git a/spn/terminal/module.go b/spn/terminal/module.go index 178bc08c2..01cd6d066 100644 --- a/spn/terminal/module.go +++ b/spn/terminal/module.go @@ -4,8 +4,8 @@ import ( "flag" "time" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/unit" ) diff --git a/spn/terminal/msg.go b/spn/terminal/msg.go index 8ca004890..764601e87 100644 --- a/spn/terminal/msg.go +++ b/spn/terminal/msg.go @@ -4,7 +4,7 @@ import ( "fmt" "runtime" - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/unit" ) diff --git a/spn/terminal/msgtypes.go b/spn/terminal/msgtypes.go index df712618b..a7d244b34 100644 --- a/spn/terminal/msgtypes.go +++ b/spn/terminal/msgtypes.go @@ -1,8 +1,8 @@ package terminal import ( - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/varint" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/varint" ) /* diff --git a/spn/terminal/operation.go b/spn/terminal/operation.go index 100936ec7..23249be05 100644 --- a/spn/terminal/operation.go +++ b/spn/terminal/operation.go @@ -8,9 +8,9 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/container" - "github.com/safing/portbase/log" - "github.com/safing/portbase/utils" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/utils" ) // Operation is an interface for all operations. 
diff --git a/spn/terminal/operation_counter.go b/spn/terminal/operation_counter.go index 59d175e03..1a732ce86 100644 --- a/spn/terminal/operation_counter.go +++ b/spn/terminal/operation_counter.go @@ -6,10 +6,10 @@ import ( "sync" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/formats/dsd" - "github.com/safing/portbase/formats/varint" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/portmaster/base/log" ) // CounterOpType is the type ID for the Counter Operation. diff --git a/spn/terminal/session.go b/spn/terminal/session.go index fa2d16956..f1fa1424d 100644 --- a/spn/terminal/session.go +++ b/spn/terminal/session.go @@ -7,7 +7,7 @@ import ( "sync/atomic" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) const ( diff --git a/spn/terminal/terminal.go b/spn/terminal/terminal.go index bbccad2fa..dc0587564 100644 --- a/spn/terminal/terminal.go +++ b/spn/terminal/terminal.go @@ -9,10 +9,10 @@ import ( "github.com/tevino/abool" "github.com/safing/jess" - "github.com/safing/portbase/container" - "github.com/safing/portbase/log" - "github.com/safing/portbase/modules" - "github.com/safing/portbase/rng" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" ) diff --git a/spn/terminal/terminal_test.go b/spn/terminal/terminal_test.go index b458f696f..7d0a83434 100644 --- a/spn/terminal/terminal_test.go +++ b/spn/terminal/terminal_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/safing/portbase/container" + "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/terminal/testing.go b/spn/terminal/testing.go index 22b12608c..1c59de690 100644 --- a/spn/terminal/testing.go +++ b/spn/terminal/testing.go @@ -4,8 +4,8 @@ import ( "context" "time" - "github.com/safing/portbase/container" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" ) diff --git a/spn/unit/unit_debug.go b/spn/unit/unit_debug.go index 0ba053bd8..840ca479e 100644 --- a/spn/unit/unit_debug.go +++ b/spn/unit/unit_debug.go @@ -4,7 +4,7 @@ import ( "sync" "time" - "github.com/safing/portbase/log" + "github.com/safing/portmaster/base/log" ) // UnitDebugger is used to debug unit leaks. 
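The preceding patch is a purely mechanical rewrite: only the import prefix changes, from the old portbase module to its new location inside the portmaster monorepo, while all package identifiers stay the same. As a minimal sketch of what this looks like for a consumer (a hypothetical caller, not part of the patch), only the import block changes and call sites keep working unchanged:

package example

import (
	// Before the move:
	// "github.com/safing/portbase/log"
	// After the move into the monorepo:
	"github.com/safing/portmaster/base/log"
)

func logSomething() {
	// Call sites are untouched; only the import path differs.
	log.Info("unchanged call sites keep working")
}
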
From 3533f819a3a6dae1217baeca8226cc1f9cbedda7 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Wed, 5 Jun 2024 15:50:19 +0200
Subject: [PATCH 02/56] Add new simple module mgr

---
 service/mgr/doc.go     |   2 +
 service/mgr/events.go  | 175 +++++++++++++++++++++
 service/mgr/manager.go | 164 ++++++++++++++++++++
 service/mgr/module.go  | 162 +++++++++++++++++++
 service/mgr/worker.go  | 285 +++++++++++++++++++++++++++++++++++++++++
 5 files changed, 788 insertions(+)
 create mode 100644 service/mgr/doc.go
 create mode 100644 service/mgr/events.go
 create mode 100644 service/mgr/manager.go
 create mode 100644 service/mgr/module.go
 create mode 100644 service/mgr/worker.go

diff --git a/service/mgr/doc.go b/service/mgr/doc.go
new file mode 100644
index 000000000..9b5fff7c5
--- /dev/null
+++ b/service/mgr/doc.go
@@ -0,0 +1,2 @@
+// Package mgr provides simple managing of flow control and logging.
+package mgr
diff --git a/service/mgr/events.go b/service/mgr/events.go
new file mode 100644
index 000000000..da7c23e2d
--- /dev/null
+++ b/service/mgr/events.go
@@ -0,0 +1,175 @@
+//nolint:structcheck,golint // TODO: Seems broken for generics.
+package mgr
+
+import (
+	"fmt"
+	"slices"
+	"sync"
+	"sync/atomic"
+)
+
+// EventMgr is a simple event manager.
+type EventMgr[T any] struct {
+	name string
+	mgr  *Manager
+	lock sync.Mutex
+
+	subs      []*EventSubscription[T]
+	callbacks []*EventCallback[T]
+}
+
+// EventSubscription is a subscription to an event.
+type EventSubscription[T any] struct {
+	name     string
+	events   chan T
+	canceled atomic.Bool
+}
+
+// EventCallback is a registered callback to an event.
+type EventCallback[T any] struct {
+	name     string
+	callback EventCallbackFunc[T]
+	canceled atomic.Bool
+}
+
+// EventCallbackFunc defines the event callback function.
+type EventCallbackFunc[T any] func(*WorkerCtx, T) (cancel bool, err error)
+
+// NewEventMgr returns a new event manager.
+// It is easiest to use as a public field on a struct,
+// so that others can simply Subscribe() or AddCallback().
+func NewEventMgr[T any](eventName string, mgr *Manager) *EventMgr[T] {
+	return &EventMgr[T]{
+		name: eventName,
+		mgr:  mgr,
+	}
+}
+
+// Subscribe subscribes to events.
+// The received events are shared among all subscribers and callbacks.
+// Be sure to apply proper concurrency safeguards, if applicable.
+func (em *EventMgr[T]) Subscribe(subscriberName string, chanSize int) *EventSubscription[T] {
+	em.lock.Lock()
+	defer em.lock.Unlock()
+
+	es := &EventSubscription[T]{
+		name:   subscriberName,
+		events: make(chan T, chanSize),
+	}
+
+	em.subs = append(em.subs, es)
+	return es
+}
+
+// AddCallback adds a callback to be executed on events.
+// The received events are shared among all subscribers and callbacks.
+// Be sure to apply proper concurrency safeguards, if applicable.
+func (em *EventMgr[T]) AddCallback(callbackName string, callback EventCallbackFunc[T]) {
+	em.lock.Lock()
+	defer em.lock.Unlock()
+
+	ec := &EventCallback[T]{
+		name:     callbackName,
+		callback: callback,
+	}
+
+	em.callbacks = append(em.callbacks, ec)
+}
+
+// Submit submits a new event.
+func (em *EventMgr[T]) Submit(event T) {
+	em.lock.Lock()
+	defer em.lock.Unlock()
+
+	var anyCanceled bool
+
+	// Send to subscriptions.
+	for _, sub := range em.subs {
+		// Check if subscription was canceled.
+		if sub.canceled.Load() {
+			anyCanceled = true
+			continue
+		}
+
+		// Submit via channel.
+ select { + case sub.events <- event: + default: + if em.mgr != nil { + em.mgr.Warn( + "event subscription channel overflow", + "event", em.name, + "subscriber", sub.name, + ) + } + } + } + + // Run callbacks. + for _, ec := range em.callbacks { + // Execute callback. + var ( + cancel bool + err error + ) + if em.mgr != nil { + // Prefer executing in worker. + wkrErr := em.mgr.Do("execute event callback", func(w *WorkerCtx) error { + cancel, err = ec.callback(w, event) //nolint:scopelint // Execution is within scope. + return nil + }) + if wkrErr != nil { + err = fmt.Errorf("callback execution failed: %w", wkrErr) + } + } else { + cancel, err = ec.callback(nil, event) + } + + // Handle error and cancelation. + if err != nil && em.mgr != nil { + em.mgr.Warn( + "event callback failed", + "event", em.name, + "callback", ec.name, + "err", err, + ) + } + if cancel { + ec.canceled.Store(true) + anyCanceled = true + } + } + + // If any canceled subscription/callback was seen, clean the slices. + if anyCanceled { + em.clean() + } +} + +// clean removes all canceled subscriptions and callbacks. +func (em *EventMgr[T]) clean() { + em.subs = slices.DeleteFunc[[]*EventSubscription[T], *EventSubscription[T]](em.subs, func(es *EventSubscription[T]) bool { + return es.canceled.Load() + }) + em.callbacks = slices.DeleteFunc[[]*EventCallback[T], *EventCallback[T]](em.callbacks, func(ec *EventCallback[T]) bool { + return ec.canceled.Load() + }) +} + +// Events returns a read channel for the events. +// The received events are shared among all subscribers and callbacks. +// Be sure to apply proper concurrency safeguards, if applicable. +func (es *EventSubscription[T]) Events() <-chan T { + return es.events +} + +// Cancel cancels the subscription. +// The events channel is not closed, but will not receive new events. +func (es *EventSubscription[T]) Cancel() { + es.canceled.Store(true) +} + +// Done returns whether the event subscription has been canceled. +func (es *EventSubscription[T]) Done() bool { + return es.canceled.Load() +} diff --git a/service/mgr/manager.go b/service/mgr/manager.go new file mode 100644 index 000000000..cd346c7fa --- /dev/null +++ b/service/mgr/manager.go @@ -0,0 +1,164 @@ +package mgr + +import ( + "context" + "log/slog" + "sync/atomic" + "time" +) + +// Manager manages workers. +type Manager struct { + name string + logger *slog.Logger + + ctx context.Context + cancelCtx context.CancelFunc + + workerCnt atomic.Int32 + workersDone chan struct{} +} + +// New returns a new manager. +func New(name string) *Manager { + return NewWithContext(context.Background(), name) +} + +// NewWithContext returns a new manager that uses the given context. +func NewWithContext(ctx context.Context, name string) *Manager { + return newManager(ctx, name, "manager") +} + +func newManager(ctx context.Context, name string, logNameKey string) *Manager { + m := &Manager{ + name: name, + logger: slog.Default().With(logNameKey, name), + workersDone: make(chan struct{}), + } + m.ctx, m.cancelCtx = context.WithCancel(ctx) + return m +} + +// Name returns the manager name. +func (m *Manager) Name() string { + return m.name +} + +// Ctx returns the worker context. +func (m *Manager) Ctx() context.Context { + return m.ctx +} + +// Cancel cancels the worker context. +func (m *Manager) Cancel() { + m.cancelCtx() +} + +// Done returns the context Done channel. +func (m *Manager) Done() <-chan struct{} { + return m.ctx.Done() +} + +// IsDone checks whether the manager context is done. 
+func (m *Manager) IsDone() bool { + return m.ctx.Err() != nil +} + +// LogEnabled reports whether the logger emits log records at the given level. +// The manager context is automatically supplied. +func (m *Manager) LogEnabled(level slog.Level) bool { + return m.logger.Enabled(m.ctx, level) +} + +// Debug logs at LevelDebug. +// The manager context is automatically supplied. +func (m *Manager) Debug(msg string, args ...any) { + m.logger.DebugContext(m.ctx, msg, args...) +} + +// Info logs at LevelInfo. +// The manager context is automatically supplied. +func (m *Manager) Info(msg string, args ...any) { + m.logger.InfoContext(m.ctx, msg, args...) +} + +// Warn logs at LevelWarn. +// The manager context is automatically supplied. +func (m *Manager) Warn(msg string, args ...any) { + m.logger.WarnContext(m.ctx, msg, args...) +} + +// Error logs at LevelError. +// The manager context is automatically supplied. +func (m *Manager) Error(msg string, args ...any) { + m.logger.ErrorContext(m.ctx, msg, args...) +} + +// Log emits a log record with the current time and the given level and message. +// The manager context is automatically supplied. +func (m *Manager) Log(level slog.Level, msg string, args ...any) { + m.logger.Log(m.ctx, level, msg, args...) +} + +// LogAttrs is a more efficient version of Log() that accepts only Attrs. +// The manager context is automatically supplied. +func (m *Manager) LogAttrs(level slog.Level, msg string, attrs ...slog.Attr) { + m.logger.LogAttrs(m.ctx, level, msg, attrs...) +} + +// WaitForWorkers waits for all workers of this manager to be done. +// The default maximum waiting time is one minute. +func (m *Manager) WaitForWorkers(max time.Duration) (done bool) { + // Return immediately if there are no workers. + if m.workerCnt.Load() == 0 { + return true + } + + // Setup timers. + reCheckDuration := 100 * time.Millisecond + if max <= 0 { + max = time.Minute + } + reCheck := time.NewTimer(reCheckDuration) + maxWait := time.NewTimer(max) + defer reCheck.Stop() + defer maxWait.Stop() + + // Wait for workers to finish, plus check the count in intervals. + for { + if m.workerCnt.Load() == 0 { + return true + } + + select { + case <-m.workersDone: + return true + + case <-reCheck.C: + // Check worker count again. + // This is a dead simple and effective way to avoid all the channel race conditions. + reCheckDuration *= 2 + reCheck.Reset(reCheckDuration) + + case <-maxWait.C: + return m.workerCnt.Load() == 0 + } + } +} + +func (m *Manager) workerStart() { + m.workerCnt.Add(1) +} + +func (m *Manager) workerDone() { + if m.workerCnt.Add(-1) == 0 { + // Notify all waiters. + for { + select { + case m.workersDone <- struct{}{}: + default: + return + } + } + } +} diff --git a/service/mgr/module.go b/service/mgr/module.go new file mode 100644 index 000000000..019138af5 --- /dev/null +++ b/service/mgr/module.go @@ -0,0 +1,162 @@ +package mgr + +import ( + "context" + "errors" + "fmt" + "reflect" + "strings" + "sync" +) + +// Group describes a group of modules. +type Group struct { + modules []*groupModule + + ctx context.Context + cancelCtx context.CancelFunc + ctxLock sync.Mutex +} + +type groupModule struct { + module Module + mgr *Manager +} + +// Module is an manage-able instance of some component. +type Module interface { + Start(mgr *Manager) error + Stop(mgr *Manager) error +} + +// NewGroup returns a new group of modules. +func NewGroup(modules ...Module) *Group { + // Create group. 
+ g := &Group{ + modules: make([]*groupModule, 0, len(modules)), + } + g.initGroupContext() + + // Initialize groups modules. + for _, m := range modules { + // Skip non-values. + switch { + case m == nil: + // Skip nil values to allow for cleaner code. + continue + case reflect.ValueOf(m).IsNil(): + // If nil values are given via a struct, they are will be interfaces to a + // nil type. Ignore these too. + continue + } + + // Add module to group. + g.modules = append(g.modules, &groupModule{ + module: m, + mgr: newManager(g.ctx, makeModuleName(m), "module"), + }) + } + + return g +} + +// Start starts all modules in the group in the defined order. +// If a module fails to start, itself and all previous modules +// will be stopped in the reverse order. +func (g *Group) Start() error { + g.initGroupContext() + + for i, m := range g.modules { + err := m.module.Start(m.mgr) + if err != nil { + g.stopFrom(i) + return fmt.Errorf("failed to start %s: %w", makeModuleName(m.module), err) + } + m.mgr.Info("started") + } + return nil +} + +// Stop stops all modules in the group in the reverse order. +func (g *Group) Stop() (ok bool) { + return g.stopFrom(len(g.modules) - 1) +} + +func (g *Group) stopFrom(index int) (ok bool) { + ok = true + for i := index; i >= 0; i-- { + m := g.modules[i] + err := m.module.Stop(m.mgr) + if err != nil { + m.mgr.Error("failed to stop", "err", err) + ok = false + } + m.mgr.Cancel() + if m.mgr.WaitForWorkers(0) { + m.mgr.Info("stopped") + } else { + ok = false + m.mgr.Error( + "failed to stop", + "err", "timed out", + "workerCnt", m.mgr.workerCnt.Load(), + ) + } + } + + g.stopGroupContext() + return +} + +func (g *Group) initGroupContext() { + g.ctxLock.Lock() + defer g.ctxLock.Unlock() + + g.ctx, g.cancelCtx = context.WithCancel(context.Background()) +} + +func (g *Group) stopGroupContext() { + g.ctxLock.Lock() + defer g.ctxLock.Unlock() + + g.cancelCtx() +} + +// Done returns the context Done channel. +func (g *Group) Done() <-chan struct{} { + g.ctxLock.Lock() + defer g.ctxLock.Unlock() + + return g.ctx.Done() +} + +// IsDone checks whether the manager context is done. +func (g *Group) IsDone() bool { + g.ctxLock.Lock() + defer g.ctxLock.Unlock() + + return g.ctx.Err() != nil +} + +// RunModules is a simple wrapper function to start modules and stop them again +// when the given context is canceled. +func RunModules(ctx context.Context, modules ...Module) error { + g := NewGroup(modules...) + + // Start module. + if err := g.Start(); err != nil { + return fmt.Errorf("failed to start: %w", err) + } + + // Stop module when context is canceled. + <-ctx.Done() + if !g.Stop() { + return errors.New("failed to stop") + } + + return nil +} + +func makeModuleName(m Module) string { + return strings.TrimPrefix(fmt.Sprintf("%T", m), "*") +} diff --git a/service/mgr/worker.go b/service/mgr/worker.go new file mode 100644 index 000000000..109716f9b --- /dev/null +++ b/service/mgr/worker.go @@ -0,0 +1,285 @@ +package mgr + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "runtime/debug" + "strings" + "time" +) + +// workerContextKey is a key used for the context key/value storage. +type workerContextKey struct{} + +// WorkerCtxContextKey is the key used to add the WorkerCtx to a context. +var WorkerCtxContextKey = workerContextKey{} + +// WorkerCtx provides workers with the necessary environment for flow control +// and logging. 
+type WorkerCtx struct { + ctx context.Context + cancelCtx context.CancelFunc + + logger *slog.Logger +} + +// AddToCtx adds the WorkerCtx to the given context. +func (w *WorkerCtx) AddToCtx(ctx context.Context) context.Context { + return context.WithValue(ctx, WorkerCtxContextKey, w) +} + +// WorkerFromCtx returns the WorkerCtx from the given context. +func WorkerFromCtx(ctx context.Context) *WorkerCtx { + v := ctx.Value(WorkerCtxContextKey) + if w, ok := v.(*WorkerCtx); ok { + return w + } + return nil +} + +// Ctx returns the worker context. +// Is automatically canceled after the worker stops/returns, regardless of error. +func (w *WorkerCtx) Ctx() context.Context { + return w.ctx +} + +// Cancel cancels the worker context. +// Is automatically called after the worker stops/returns, regardless of error. +func (w *WorkerCtx) Cancel() { + w.cancelCtx() +} + +// Done returns the context Done channel. +func (w *WorkerCtx) Done() <-chan struct{} { + return w.ctx.Done() +} + +// IsDone checks whether the worker context is done. +func (w *WorkerCtx) IsDone() bool { + return w.ctx.Err() != nil +} + +// Logger returns the logger used by the worker context. +func (w *WorkerCtx) Logger() *slog.Logger { + return w.logger +} + +// LogEnabled reports whether the logger emits log records at the given level. +// The worker context is automatically supplied. +func (w *WorkerCtx) LogEnabled(level slog.Level) bool { + return w.logger.Enabled(w.ctx, level) +} + +// Debug logs at LevelDebug. +// The worker context is automatically supplied. +func (w *WorkerCtx) Debug(msg string, args ...any) { + w.logger.DebugContext(w.ctx, msg, args...) +} + +// Info logs at LevelInfo. +// The worker context is automatically supplied. +func (w *WorkerCtx) Info(msg string, args ...any) { + w.logger.InfoContext(w.ctx, msg, args...) +} + +// Warn logs at LevelWarn. +// The worker context is automatically supplied. +func (w *WorkerCtx) Warn(msg string, args ...any) { + w.logger.WarnContext(w.ctx, msg, args...) +} + +// Error logs at LevelError. +// The worker context is automatically supplied. +func (w *WorkerCtx) Error(msg string, args ...any) { + w.logger.ErrorContext(w.ctx, msg, args...) +} + +// Log emits a log record with the current time and the given level and message. +// The worker context is automatically supplied. +func (w *WorkerCtx) Log(level slog.Level, msg string, args ...any) { + w.logger.Log(w.ctx, level, msg, args...) +} + +// LogAttrs is a more efficient version of Log() that accepts only Attrs. +// The worker context is automatically supplied. +func (w *WorkerCtx) LogAttrs(level slog.Level, msg string, attrs ...slog.Attr) { + w.logger.LogAttrs(w.ctx, level, msg, attrs...) +} + +// Go starts the given function in a goroutine (as a "worker"). +// The worker context has +// - A separate context which is canceled when the functions returns. +// - Access to named structure logging. +// - Given function is re-run after failure (with backoff). +// - Panic catching. +// - Flow control helpers. +func (m *Manager) Go(name string, fn func(w *WorkerCtx) error) { + go m.manageWorker(name, fn) +} + +func (m *Manager) manageWorker(name string, fn func(w *WorkerCtx) error) { + m.workerStart() + defer m.workerDone() + + w := &WorkerCtx{ + logger: m.logger.With("worker", name), + } + + backoff := time.Second + failCnt := 0 + + for { + panicInfo, err := m.runWorker(w, fn) + switch { + case err == nil: + // No error means that the worker is finished. 
+			return
+
+		case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
+			// A canceled context or exceeded deadline also means that the worker is finished.
+			return
+
+		default:
+			// Any other error triggers a restart with backoff.
+
+			// If manager is stopping, just log error and return.
+			if m.IsDone() {
+				if panicInfo != "" {
+					m.Error(
+						"worker failed",
+						"err", err,
+						"file", panicInfo,
+					)
+				} else {
+					m.Error(
+						"worker failed",
+						"err", err,
+					)
+				}
+				return
+			}
+
+			// Count failure and increase backoff (up to limit).
+			failCnt++
+			backoff *= 2
+			if backoff > time.Minute {
+				backoff = time.Minute
+			}
+
+			// Log error and retry after backoff duration.
+			if panicInfo != "" {
+				m.Error(
+					"worker failed",
+					"failCnt", failCnt,
+					"backoff", backoff,
+					"err", err,
+					"file", panicInfo,
+				)
+			} else {
+				m.Error(
+					"worker failed",
+					"failCnt", failCnt,
+					"backoff", backoff,
+					"err", err,
+				)
+			}
+			select {
+			case <-time.After(backoff):
+			case <-m.ctx.Done():
+				return
+			}
+		}
+	}
+}
+
+// Do directly executes the given function (as a "worker").
+// The worker context has
+// - A separate context which is canceled when the function returns.
+// - Access to named structure logging.
+// - Given function is re-run after failure (with backoff).
+// - Panic catching.
+// - Flow control helpers.
+func (m *Manager) Do(name string, fn func(w *WorkerCtx) error) error {
+	m.workerStart()
+	defer m.workerDone()
+
+	// Create context.
+	w := &WorkerCtx{
+		logger: m.logger.With("worker", name),
+	}
+
+	// Run worker.
+	panicInfo, err := m.runWorker(w, fn)
+	switch {
+	case err == nil:
+		// No error means that the worker is finished.
+		return nil
+
+	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
+		// A canceled context or exceeded deadline also means that the worker is finished.
+		return err
+
+	default:
+		// Log error and return.
+		if panicInfo != "" {
+			m.Error(
+				"worker failed",
+				"err", err,
+				"file", panicInfo,
+			)
+		} else {
+			m.Error(
+				"worker failed",
+				"err", err,
+			)
+		}
+		return err
+	}
+}
+
+func (m *Manager) runWorker(w *WorkerCtx, fn func(w *WorkerCtx) error) (panicInfo string, err error) {
+	// Create worker context that is canceled when worker finished or dies.
+	w.ctx, w.cancelCtx = context.WithCancel(m.Ctx())
+	defer w.Cancel()
+
+	// Recover from panic.
+	defer func() {
+		panicVal := recover()
+		if panicVal != nil {
+			err = fmt.Errorf("panic: %s", panicVal)
+
+			// Print panic to stderr.
+			stackTrace := string(debug.Stack())
+			fmt.Fprintf(
+				os.Stderr,
+				"===== PANIC =====\n%s\n\n%s===== END =====\n",
+				panicVal,
+				stackTrace,
+			)
+
+			// Find the line in the stack trace that refers to where the panic occurred.
+			stackLines := strings.Split(stackTrace, "\n")
+			foundPanic := false
+			for i, line := range stackLines {
+				if !foundPanic {
+					if strings.Contains(line, "panic(") {
+						foundPanic = true
+					}
+				} else {
+					if strings.Contains(line, "portmaster") {
+						if i+1 < len(stackLines) {
+							panicInfo = strings.SplitN(strings.TrimSpace(stackLines[i+1]), " ", 2)[0]
+						}
+						break
+					}
+				}
+			}
+		}
+	}()
+
+	err = fn(w)
+	return //nolint
+}

From 4f90afe8c8eae43340f17bef9eafcdf86caac759 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Wed, 5 Jun 2024 15:50:37 +0200
Subject: [PATCH 03/56] [WIP] Switch to new simple module mgr

---
 base/api/module.go   | 45 ++++++++++++++++++++++++++++++
 service/instance.go  | 55 ++++++++++++++++++++++++++++++++++++++++
 service/ui/api.go    |  1 -
 service/ui/module.go | 49 +++++++++++++++++++++++++++++++--------
 service/ui/serve.go  |  2 --
 5 files changed, 140 insertions(+), 12 deletions(-)
 create mode 100644 base/api/module.go
 create mode 100644 service/instance.go

diff --git a/base/api/module.go b/base/api/module.go
new file mode 100644
index 000000000..7aa71a861
--- /dev/null
+++ b/base/api/module.go
@@ -0,0 +1,45 @@
+package api
+
+import (
+	"errors"
+	"sync/atomic"
+
+	"github.com/safing/portmaster/service/mgr"
+)
+
+// API is the HTTP/Websockets API module.
+type API struct {
+	instance instance
+}
+
+// Start starts the module.
+func (api *API) Start(_ *mgr.Manager) error {
+	return start()
+}
+
+// Stop stops the module.
+func (api *API) Stop(_ *mgr.Manager) error {
+	return stop()
+}
+
+var (
+	shimLoaded atomic.Bool
+)
+
+// New returns a new API module.
+func New(instance instance) (*API, error) {
+	if !shimLoaded.CompareAndSwap(false, true) {
+		return nil, errors.New("only one instance allowed")
+	}
+
+	if err := prep(); err != nil {
+		return nil, err
+	}
+
+	return &API{
+		instance: instance,
+	}, nil
+}
+
+type instance interface {
+}
diff --git a/service/instance.go b/service/instance.go
new file mode 100644
index 000000000..1cbf8aa7f
--- /dev/null
+++ b/service/instance.go
@@ -0,0 +1,55 @@
+package service
+
+import (
+	"fmt"
+
+	"github.com/safing/portmaster/base/api"
+	"github.com/safing/portmaster/service/mgr"
+	"github.com/safing/portmaster/service/ui"
+)
+
+// Instance is an instance of a Portmaster service.
+type Instance struct {
+	*mgr.Group
+
+	version string
+
+	api *api.API
+	ui  *ui.UI
+}
+
+// New returns a new Portmaster service instance.
+func New(version string) (*Instance, error) {
+	// Create instance to pass it to modules.
+	instance := &Instance{
+		version: version,
+	}
+
+	var err error
+	instance.ui, err = ui.New(instance)
+	if err != nil {
+		return nil, fmt.Errorf("create ui module: %w", err)
+	}
+
+	// Add all modules to instance group.
+	instance.Group = mgr.NewGroup(
+		instance.ui,
+	)
+
+	return instance, nil
+}
+
+// Version returns the version.
+func (i *Instance) Version() string {
+	return i.version
+}
+
+// API returns the api module.
+func (i *Instance) API() *api.API {
+	return i.api
+}
+
+// UI returns the ui module.
+func (i *Instance) UI() *ui.UI {
+	return i.ui
+}
diff --git a/service/ui/api.go b/service/ui/api.go
index 5e57dfe5c..de7e77585 100644
--- a/service/ui/api.go
+++ b/service/ui/api.go
@@ -9,7 +9,6 @@ func registerAPIEndpoints() error {
 	return api.RegisterEndpoint(api.Endpoint{
 		Path:        "ui/reload",
 		Write:       api.PermitUser,
-		BelongsTo:   module,
 		ActionFunc:  reloadUI,
 		Name:        "Reload UI Assets",
 		Description: "Removes all assets from the cache and reloads the current (possibly updated) version from disk when requested.",
diff --git a/service/ui/module.go b/service/ui/module.go
index 0e10ed757..aa24e557a 100644
--- a/service/ui/module.go
+++ b/service/ui/module.go
@@ -1,16 +1,14 @@
 package ui
 
 import (
-	"github.com/safing/portbase/dataroot"
-	"github.com/safing/portbase/log"
-	"github.com/safing/portbase/modules"
-)
-
-var module *modules.Module
+	"errors"
+	"sync/atomic"
 
-func init() {
-	module = modules.Register("ui", prep, start, nil, "api", "updates")
-}
+	"github.com/safing/portmaster/base/api"
+	"github.com/safing/portmaster/base/dataroot"
+	"github.com/safing/portmaster/base/log"
+	"github.com/safing/portmaster/service/mgr"
+)
 
 func prep() error {
 	if err := registerAPIEndpoints(); err != nil {
@@ -36,3 +34,36 @@ func start() error {
 
 	return nil
 }
+
+// UI serves the user interface files.
+type UI struct {
+	instance instance
+}
+
+// Start starts the module.
+func (ui *UI) Start(_ *mgr.Manager) error {
+	return start()
+}
+
+// Stop stops the module.
+func (ui *UI) Stop(_ *mgr.Manager) error {
+	return nil
+}
+
+var (
+	shimLoaded atomic.Bool
+)
+
+// New returns a new UI module.
+func New(instance instance) (*UI, error) {
+	if shimLoaded.CompareAndSwap(false, true) {
+		return &UI{
+			instance: instance,
+		}, nil
+	}
+	return nil, errors.New("only one instance allowed")
+}
+
+type instance interface {
+	API() *api.API
+}
diff --git a/service/ui/serve.go b/service/ui/serve.go
index d8c9f5f26..9dca6c309 100644
--- a/service/ui/serve.go
+++ b/service/ui/serve.go
@@ -56,8 +56,6 @@ type archiveServer struct {
 	defaultModuleName string
 }
 
-func (bs *archiveServer) BelongsTo() *modules.Module { return module }
-
 func (bs *archiveServer) ReadPermission(*http.Request) api.Permission { return api.PermitAnyone }
 
 func (bs *archiveServer) WritePermission(*http.Request) api.Permission { return api.NotSupported }

From 311536cec975a36bccafaaed3dced84b3d46f238 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 6 Jun 2024 17:32:26 +0200
Subject: [PATCH 04/56] Add StateMgr and more worker variants

---
 service/mgr/states.go | 125 ++++++++++++++++++++++++++++++++++++++++++
 service/mgr/worker.go | 125 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 250 insertions(+)
 create mode 100644 service/mgr/states.go

diff --git a/service/mgr/states.go b/service/mgr/states.go
new file mode 100644
index 000000000..229c395c3
--- /dev/null
+++ b/service/mgr/states.go
@@ -0,0 +1,125 @@
+package mgr
+
+import (
+	"slices"
+	"sync"
+	"time"
+)
+
+// StateMgr is a simple state manager.
+type StateMgr struct {
+	states     []State
+	statesLock sync.Mutex
+
+	statesEventMgr *EventMgr[StateUpdate]
+
+	mgr *Manager
+}
+
+// State describes the state of a manager or module.
+type State struct {
+	ID      string    // Required.
+	Name    string    // Required.
+	Message string    // Optional.
+	Type    StateType // Optional.
+	Time    time.Time // Optional, will be set to current time if not set.
+	Data    any       // Optional.
+}
+
+// StateType defines commonly used states.
+type StateType string
+
+// State Types.
+const (
+	StateTypeUndefined = ""
+	StateTypeHint      = "hint"
+	StateTypeWarning   = "warning"
+	StateTypeError     = "error"
+)
+
+// StateUpdate is used to update others about a state change.
+type StateUpdate struct {
+	Name   string
+	States []State
+}
+
+// NewStateMgr returns a new state manager.
+// It is easiest to use as a public field on a struct,
+// so that others can simply Subscribe() or AddCallback().
+func NewStateMgr(mgr *Manager) *StateMgr {
+	return &StateMgr{
+		statesEventMgr: NewEventMgr[StateUpdate]("state update", mgr),
+		mgr:            mgr,
+	}
+}
+
+// Add adds a state.
+// If a state with the same ID already exists, it is replaced.
+func (m *StateMgr) Add(s State) {
+	m.statesLock.Lock()
+	defer m.statesLock.Unlock()
+
+	if s.Time.IsZero() {
+		s.Time = time.Now()
+	}
+
+	// Update or add state.
+	index := slices.IndexFunc[[]State, State](m.states, func(es State) bool {
+		return es.ID == s.ID
+	})
+	if index >= 0 {
+		m.states[index] = s
+	} else {
+		m.states = append(m.states, s)
+	}
+
+	m.statesEventMgr.Submit(m.Export())
+}
+
+// Remove removes the state with the given ID.
+func (m *StateMgr) Remove(id string) {
+	m.statesLock.Lock()
+	defer m.statesLock.Unlock()
+
+	m.states = slices.DeleteFunc[[]State, State](m.states, func(s State) bool {
+		return s.ID == id
+	})
+
+	m.statesEventMgr.Submit(m.Export())
+}
+
+// Clear removes all states.
+func (m *StateMgr) Clear() {
+	m.statesLock.Lock()
+	defer m.statesLock.Unlock()
+
+	m.states = nil
+
+	m.statesEventMgr.Submit(m.Export())
+}
+
+// Export returns the current states.
+func (m *StateMgr) Export() StateUpdate {
+	m.statesLock.Lock()
+	defer m.statesLock.Unlock()
+
+	name := ""
+	if m.mgr != nil {
+		name = m.mgr.name
+	}
+
+	return StateUpdate{
+		Name:   name,
+		States: slices.Clone(m.states),
+	}
+}
+
+// Subscribe subscribes to state update events.
+func (m *StateMgr) Subscribe(subscriberName string, chanSize int) *EventSubscription[StateUpdate] {
+	return m.statesEventMgr.Subscribe(subscriberName, chanSize)
+}
+
+// AddCallback adds a callback to state update events.
+func (m *StateMgr) AddCallback(callbackName string, callback EventCallbackFunc[StateUpdate]) {
+	m.statesEventMgr.AddCallback(callbackName, callback)
+}
diff --git a/service/mgr/worker.go b/service/mgr/worker.go
index 109716f9b..15464442f 100644
--- a/service/mgr/worker.go
+++ b/service/mgr/worker.go
@@ -195,6 +195,26 @@ func (m *Manager) manageWorker(name string, fn func(w *WorkerCtx) error) {
 	}
 }
 
+// Delay starts the given function delayed in a goroutine (as a "worker").
+// The worker context has
+// - A separate context which is canceled when the functions returns.
+// - Access to named structure logging.
+// - Given function is re-run after failure (with backoff).
+// - Panic catching.
+// - Flow control helpers.
+func (m *Manager) Delay(name string, delay time.Duration, fn func(w *WorkerCtx) error) {
+	go m.delayWorker(name, delay, fn)
+}
+
+func (m *Manager) delayWorker(name string, delay time.Duration, fn func(w *WorkerCtx) error) {
+	select {
+	case <-time.After(delay):
+	case <-m.ctx.Done():
+		return
+	}
+	m.manageWorker(name, fn)
+}
+
 // Do directly executes the given function (as a "worker").
 // The worker context has
 // - A separate context which is canceled when the functions returns.
@@ -283,3 +303,108 @@ func (m *Manager) runWorker(w *WorkerCtx, fn func(w *WorkerCtx) error) (panicInfo string, err error) {
 	err = fn(w)
 	return //nolint
 }
+
+// Repeat executes the given function periodically in a goroutine (as a "worker").
+// The worker context has +// - A separate context which is canceled when the functions returns. +// - Access to named structure logging. +// - Given function is re-run after failure (with backoff). +// - Panic catching. +// - Flow control helpers. +func (m *Manager) Repeat(name string, period time.Duration, fn func(w *WorkerCtx) error) { + go m.manageRepeatedWorker(name, period, fn) +} + +func (m *Manager) manageRepeatedWorker(name string, period time.Duration, fn func(w *WorkerCtx) error) { + m.workerStart() + defer m.workerDone() + + w := &WorkerCtx{ + logger: m.logger.With("worker", name), + } + + repeatTick := time.NewTicker(period) + execCnt := 0 + + backoff := time.Second + failCnt := 0 + +repeat: + for { + // Wait for repeat period. + if execCnt > 0 { + select { + case <-repeatTick.C: + case <-m.ctx.Done(): + return + } + } + + // Execute function. + execCnt++ + panicInfo, err := m.runWorker(w, fn) + + switch { + case err == nil: + // No error means that the worker is finished. + continue repeat + + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + // A canceled context or dexceeded eadline also means that the worker is finished. + continue repeat + + default: + // Any other errors triggers a restart with backoff. + + // If manager is stopping, just log error and return. + if m.IsDone() { + if panicInfo != "" { + m.Error( + "worker failed", + "err", err, + "file", panicInfo, + ) + } else { + m.Error( + "worker failed", + "err", err, + ) + } + return + } + + // Count failure and increase backoff (up to limit), + failCnt++ + backoff *= 2 + if backoff > time.Minute { + backoff = time.Minute + } + + // Log error and retry after backoff duration. + if panicInfo != "" { + m.Error( + "repeated worker failed", + "execCnt", execCnt, + "failCnt", failCnt, + "backoff", backoff, + "err", err, + "file", panicInfo, + ) + } else { + m.Error( + "repeated worker failed", + "execCnt", execCnt, + "failCnt", failCnt, + "backoff", backoff, + "err", err, + ) + } + select { + case <-time.After(backoff): + case <-m.ctx.Done(): + return + } + repeatTick.Reset(period) + } + } +} From 469aef80b6e176018bd8e9a38971302dbc0be15e Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 6 Jun 2024 17:32:55 +0200 Subject: [PATCH 05/56] [WIP] Switch more modules --- base/config/basic_config.go | 15 +++------- base/config/main.go | 8 ++--- base/config/module.go | 53 ++++++++++++++++++++++++++++++++ base/config/set.go | 2 +- service/config.go | 5 ++++ service/instance.go | 30 +++++++++++++++++-- service/ui/module.go | 9 ++++-- service/updates/api.go | 6 ++-- service/updates/config.go | 11 ++++--- service/updates/export.go | 20 +++++-------- service/updates/get.go | 4 +-- service/updates/main.go | 43 +++++++++++++------------- service/updates/module.go | 60 +++++++++++++++++++++++++++++++++++++ service/updates/notify.go | 2 +- service/updates/restart.go | 13 ++++---- service/updates/upgrader.go | 17 +++++------ 16 files changed, 212 insertions(+), 86 deletions(-) create mode 100644 base/config/module.go create mode 100644 service/config.go create mode 100644 service/updates/module.go diff --git a/base/config/basic_config.go b/base/config/basic_config.go index 7898df127..2bce18b4f 100644 --- a/base/config/basic_config.go +++ b/base/config/basic_config.go @@ -1,10 +1,10 @@ package config import ( - "context" "flag" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" ) // Configuration Keys. 
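A rough usage sketch for the StateMgr and the Repeat worker variant added in patch 04 above, as a module might wire them up in Start; the Example type, the state ID and the check function are invented for illustration:

package example

import (
	"time"

	"github.com/safing/portmaster/service/mgr"
)

// Example holds its state manager as a public field, as the real modules do.
type Example struct {
	mgr    *mgr.Manager
	States *mgr.StateMgr
}

// Start wires up states and a periodic worker.
func (e *Example) Start(m *mgr.Manager) error {
	e.mgr = m
	e.States = mgr.NewStateMgr(m)

	// React whenever the state set changes.
	e.States.AddCallback("log state changes", func(_ *mgr.WorkerCtx, u mgr.StateUpdate) (cancel bool, err error) {
		// u.States holds the full current state set.
		return false, nil
	})

	// Run a periodic check; failed runs are retried by the manager with backoff.
	m.Repeat("health check", 10*time.Minute, func(w *mgr.WorkerCtx) error {
		if err := check(w); err != nil {
			e.States.Add(mgr.State{
				ID:      "example:check-failed",
				Name:    "Check Failed",
				Message: err.Error(),
				Type:    mgr.StateTypeWarning,
			})
			return err
		}
		e.States.Remove("example:check-failed")
		return nil
	})
	return nil
}

// check is a placeholder for the actual work.
func check(_ *mgr.WorkerCtx) error { return nil }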
@@ -78,14 +78,7 @@ func registerBasicOptions() error { logLevel = GetAsString(CfgLogLevel, defaultLogLevel) // Register to hook to update the log level. - if err := module.RegisterEventHook( - "config", - ChangeEvent, - "update log level", - setLogLevel, - ); err != nil { - return err - } + module.EventConfigChange.AddCallback("update log level", setLogLevel) return Register(&Option{ Name: "Development Mode", @@ -106,8 +99,8 @@ func loadLogLevel() error { return setDefaultConfigOption(CfgLogLevel, log.GetLogLevel().Name(), false) } -func setLogLevel(ctx context.Context, data interface{}) error { +func setLogLevel(_ *mgr.WorkerCtx, _ struct{}) (cancel bool, err error) { log.SetLogLevel(log.ParseLevel(logLevel())) - return nil + return false, nil } diff --git a/base/config/main.go b/base/config/main.go index 324c01f92..671e8a2b3 100644 --- a/base/config/main.go +++ b/base/config/main.go @@ -10,8 +10,8 @@ import ( "path/filepath" "sort" + "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/dataroot" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/base/utils/debug" ) @@ -20,7 +20,7 @@ import ( const ChangeEvent = "config change" var ( - module *modules.Module + // module *modules.Module dataRoot *utils.DirStructure exportConfig bool @@ -34,8 +34,8 @@ func SetDataRoot(root *utils.DirStructure) { } func init() { - module = modules.Register("config", prep, start, nil, "database") - module.RegisterEvent(ChangeEvent, true) + // module = modules.Register("config", prep, start, nil, "database") + // module.RegisterEvent(ChangeEvent, true) flag.BoolVar(&exportConfig, "export-config-options", false, "export configuration registry and exit") } diff --git a/base/config/module.go b/base/config/module.go new file mode 100644 index 000000000..f7ed07b84 --- /dev/null +++ b/base/config/module.go @@ -0,0 +1,53 @@ +package config + +import ( + "errors" + "sync/atomic" + + "github.com/safing/portmaster/service/mgr" +) + +// Config provides configuration mgmt. +type Config struct { + mgr *mgr.Manager + + instance instance + + EventConfigChange *mgr.EventMgr[struct{}] +} + +// Start starts the module. +func (u *Config) Start(m *mgr.Manager) error { + u.mgr = m + u.EventConfigChange = mgr.NewEventMgr[struct{}](ChangeEvent, u.mgr) + return start() +} + +// Stop stops the module. +func (u *Config) Stop(_ *mgr.Manager) error { + return nil +} + +var ( + module *Config + shimLoaded atomic.Bool +) + +// New returns a new UI module. 
+func New(instance instance) (*Config, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + if err := prep(); err != nil { + return nil, err + } + + module = &Config{ + instance: instance, + } + return module, nil +} + +type instance interface { +} diff --git a/base/config/set.go b/base/config/set.go index 2c40ca213..8913b0e6e 100644 --- a/base/config/set.go +++ b/base/config/set.go @@ -34,7 +34,7 @@ func signalChanges() { validityFlag = abool.NewBool(true) validityFlagLock.Unlock() - module.TriggerEvent(ChangeEvent, nil) + module.EventConfigChange.Submit(struct{}{}) } // ValidateConfig validates the given configuration and returns all validation diff --git a/service/config.go b/service/config.go new file mode 100644 index 000000000..5c6884348 --- /dev/null +++ b/service/config.go @@ -0,0 +1,5 @@ +package service + +type ServiceConfig struct { + ShutdownFunc func(exitCode int) +} diff --git a/service/instance.go b/service/instance.go index 1cbf8aa7f..bda1ffcb8 100644 --- a/service/instance.go +++ b/service/instance.go @@ -4,8 +4,10 @@ import ( "fmt" "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/ui" + "github.com/safing/portmaster/service/updates" ) // Instance is an instance of a mycoria router. @@ -14,18 +16,32 @@ type Instance struct { version string - api *api.API - ui *ui.UI + api *api.API + ui *ui.UI + updates *updates.Updates + config *config.Config } // New returns a new mycoria router instance. -func New(version string) (*Instance, error) { +func New(version string, svcCfg *ServiceConfig) (*Instance, error) { // Create instance to pass it to modules. instance := &Instance{ version: version, } var err error + instance.config, err = config.New(instance) + if err != nil { + return nil, fmt.Errorf("create config module: %w", err) + } + instance.api, err = api.New(instance) + if err != nil { + return nil, fmt.Errorf("create api module: %w", err) + } + instance.updates, err = updates.New(instance, svcCfg.ShutdownFunc) + if err != nil { + return nil, fmt.Errorf("create updates module: %w", err) + } instance.ui, err = ui.New(instance) if err != nil { return nil, fmt.Errorf("create ui module: %w", err) @@ -33,6 +49,9 @@ func New(version string) (*Instance, error) { // Add all modules to instance group. instance.Group = mgr.NewGroup( + instance.config, + instance.api, + instance.updates, instance.ui, ) @@ -53,3 +72,8 @@ func (i *Instance) API() *api.API { func (i *Instance) UI() *ui.UI { return i.ui } + +// Config returns the config module. +func (i *Instance) Config() *config.Config { + return i.config +} diff --git a/service/ui/module.go b/service/ui/module.go index aa24e557a..7a1691258 100644 --- a/service/ui/module.go +++ b/service/ui/module.go @@ -35,19 +35,22 @@ func start() error { return nil } -// UI module the user interface files. +// UI serves the user interface files. type UI struct { + mgr *mgr.Manager + instance instance } // Start starts the module. -func (ui *UI) Start(_ *mgr.Manager) error { +func (ui *UI) Start(m *mgr.Manager) error { + ui.mgr = m return start() } // Stop stops the module. 
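For context, a minimal sketch of how a launcher could construct the service with the new ServiceConfig; calling os.Exit from ShutdownFunc and the version string are assumptions for illustration, and starting the module group is left out:

package main

import (
	"log"
	"os"

	"github.com/safing/portmaster/service"
)

func main() {
	// Create the service instance; the shutdown func decides how the
	// process exits when a module (e.g. updates) requests a restart.
	inst, err := service.New("0.0.1", &service.ServiceConfig{
		ShutdownFunc: func(exitCode int) {
			os.Exit(exitCode)
		},
	})
	if err != nil {
		log.Fatalf("failed to create instance: %s", err)
	}

	log.Printf("created portmaster service %s", inst.Version())
	// Starting and stopping the module group is handled elsewhere
	// (see the mgr.Group embedded in Instance).
}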
func (ui *UI) Stop(_ *mgr.Manager) error { - return start() + return stop() } var ( diff --git a/service/updates/api.go b/service/updates/api.go index 5917ec4bb..886596203 100644 --- a/service/updates/api.go +++ b/service/updates/api.go @@ -29,9 +29,8 @@ func registerAPIEndpoints() error { Value: "", Description: "Force downloading and applying of all updates, regardless of auto-update settings.", }}, - Path: apiPathCheckForUpdates, - Write: api.PermitUser, - BelongsTo: module, + Path: apiPathCheckForUpdates, + Write: api.PermitUser, ActionFunc: func(r *api.Request) (msg string, err error) { // Check if we should also download regardless of settings. downloadAll := r.URL.Query().Has("download") @@ -58,7 +57,6 @@ func registerAPIEndpoints() error { Path: `updates/get/{identifier:[A-Za-z0-9/\.\-_]{1,255}}`, Read: api.PermitUser, ReadMethod: http.MethodGet, - BelongsTo: module, HandlerFunc: func(w http.ResponseWriter, r *http.Request) { // Get identifier from URL. var identifier string diff --git a/service/updates/config.go b/service/updates/config.go index 43f815ae9..d51f14187 100644 --- a/service/updates/config.go +++ b/service/updates/config.go @@ -1,12 +1,11 @@ package updates import ( - "context" - "github.com/tevino/abool" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates/helper" ) @@ -123,7 +122,7 @@ func initConfig() { previousDevMode = devMode() } -func updateRegistryConfig(_ context.Context, _ interface{}) error { +func updateRegistryConfig(_ *mgr.WorkerCtx, _ struct{}) (cancel bool, err error) { changed := false if enableSoftwareUpdates() != softwareUpdatesCurrentlyEnabled { @@ -162,10 +161,10 @@ func updateRegistryConfig(_ context.Context, _ interface{}) error { // Select versions depending on new indexes and modes. registry.SelectVersions() - module.TriggerEvent(VersionUpdateEvent, nil) + module.EventVersionsUpdated.Submit(struct{}{}) if softwareUpdatesCurrentlyEnabled || intelUpdatesCurrentlyEnabled { - module.Resolve("") + module.States.Clear() if err := TriggerUpdate(true, false); err != nil { log.Warningf("updates: failed to trigger update: %s", err) } @@ -175,5 +174,5 @@ func updateRegistryConfig(_ context.Context, _ interface{}) error { } } - return nil + return false, nil } diff --git a/service/updates/export.go b/service/updates/export.go index 98ae4ba99..c230f3672 100644 --- a/service/updates/export.go +++ b/service/updates/export.go @@ -1,7 +1,6 @@ package updates import ( - "context" "fmt" "sort" "strings" @@ -12,6 +11,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/base/utils/debug" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates/helper" ) @@ -152,12 +152,8 @@ func initVersionExport() (err error) { log.Warningf("updates: failed to export version information: %s", err) } - return module.RegisterEventHook( - ModuleName, - VersionUpdateEvent, - "export version status", - export, - ) + module.EventVersionsUpdated.AddCallback("export version status", export) + return nil } func (v *Versions) save() error { @@ -182,20 +178,20 @@ func (s *UpdateStateExport) save() error { } // export is an event hook. -func export(_ context.Context, _ interface{}) error { +func export(_ *mgr.WorkerCtx, _ struct{}) (cancel bool, err error) { // Export versions. 
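The hooks above illustrate the new typed event plumbing: instead of RegisterEventHook and TriggerEvent, a module owns an mgr.EventMgr[T], consumers register callbacks with the (cancel bool, err error) signature, and producers call Submit. A small sketch with an invented payload type and event name:

package example

import (
	"github.com/safing/portmaster/base/log"
	"github.com/safing/portmaster/service/mgr"
)

// VersionInfo is a stand-in payload type for a typed event.
type VersionInfo struct {
	Current string
}

// wireEvents sketches the typed event flow: an EventMgr is created with the
// module's manager, callbacks receive the typed payload, and Submit fans an
// event out to all callbacks and subscribers.
func wireEvents(m *mgr.Manager) *mgr.EventMgr[VersionInfo] {
	events := mgr.NewEventMgr[VersionInfo]("versions updated", m)

	// Callbacks use the (cancel bool, err error) signature seen above;
	// returning false keeps the callback registered.
	events.AddCallback("log version", func(_ *mgr.WorkerCtx, v VersionInfo) (cancel bool, err error) {
		log.Infof("example: now on version %s", v.Current)
		return false, nil
	})

	// Producers simply submit the payload.
	events.Submit(VersionInfo{Current: "1.0.0"})
	return events
}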
if err := GetVersions().save(); err != nil { - return err + return false, err } if err := GetSimpleVersions().save(); err != nil { - return err + return false, err } // Export udpate state. if err := GetStateExport().save(); err != nil { - return err + return false, err } - return nil + return false, nil } // AddToDebugInfo adds the update system status to the given debug.Info. diff --git a/service/updates/get.go b/service/updates/get.go index 4a35535fd..bac9ae148 100644 --- a/service/updates/get.go +++ b/service/updates/get.go @@ -16,7 +16,7 @@ func GetPlatformFile(identifier string) (*updater.File, error) { return nil, err } - module.TriggerEvent(VersionUpdateEvent, nil) + module.EventVersionsUpdated.Submit(struct{}{}) return file, nil } @@ -29,7 +29,7 @@ func GetFile(identifier string) (*updater.File, error) { return nil, err } - module.TriggerEvent(VersionUpdateEvent, nil) + module.EventVersionsUpdated.Submit(struct{}{}) return file, nil } diff --git a/service/updates/main.go b/service/updates/main.go index a8d500388..46bfdc09d 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -9,10 +9,10 @@ import ( "runtime" "time" + "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates/helper" ) @@ -43,7 +43,7 @@ const ( ) var ( - module *modules.Module + // module *modules.Module registry *updater.ResourceRegistry userAgentFromFlag string @@ -80,9 +80,10 @@ const ( ) func init() { - module = modules.Register(ModuleName, prep, start, stop, "base") - module.RegisterEvent(VersionUpdateEvent, true) - module.RegisterEvent(ResourceUpdateEvent, true) + // FIXME: + // module = modules.Register(ModuleName, prep, start, stop, "base") + // module.RegisterEvent(VersionUpdateEvent, true) + // module.RegisterEvent(ResourceUpdateEvent, true) flag.StringVar(&updateServerFromFlag, "update-server", "", "set an alternative update server (full URL)") flag.StringVar(&userAgentFromFlag, "update-agent", "", "set an alternative user agent for requests to the update server") @@ -110,15 +111,9 @@ func prep() error { func start() error { initConfig() - restartTask = module.NewTask("automatic restart", automaticRestart).MaxDelay(10 * time.Minute) + module.mgr.Repeat("automatic restart", 10*time.Minute, automaticRestart) - if err := module.RegisterEventHook( - "config", - "config change", - "update registry config", - updateRegistryConfig); err != nil { - return err - } + module.instance.Config().EventConfigChange.AddCallback("update registry config", updateRegistryConfig) // create registry registry = &updater.ResourceRegistry{ @@ -175,7 +170,7 @@ func start() error { log.Warningf("updates: %s", warning) } - err = registry.LoadIndexes(module.Ctx) + err = registry.LoadIndexes(module.mgr.Ctx()) if err != nil { log.Warningf("updates: failed to load indexes: %s", err) } @@ -186,7 +181,7 @@ func start() error { } registry.SelectVersions() - module.TriggerEvent(VersionUpdateEvent, nil) + module.EventVersionsUpdated.Submit(struct{}{}) // Initialize the version export - this requires the registry to be set up. 
err = initVersionExport() @@ -195,14 +190,18 @@ func start() error { } // start updater task - updateTask = module.NewTask("updater", func(ctx context.Context, task *modules.Task) error { - return checkForUpdates(ctx) - }) + // FIXME: remove + // updateTask = module.NewTask("updater", func(ctx context.Context, task *modules.Task) error { + // return checkForUpdates(ctx) + // }) if !disableTaskSchedule { - updateTask. - Repeat(updateTaskRepeatDuration). - MaxDelay(30 * time.Minute) + module.mgr.Repeat("updater", 30*time.Minute, checkForUpdates) + + // FIXME: remove + // updateTask. + // Repeat(updateTaskRepeatDuration). + // MaxDelay(30 * time.Minute) } if updateASAP { @@ -318,7 +317,7 @@ func checkForUpdates(ctx context.Context) (err error) { // Purge old resources registry.Purge(2) - module.TriggerEvent(ResourceUpdateEvent, nil) + module.EventResourcesUpdated.Submit(struct{}{}) return nil } diff --git a/service/updates/module.go b/service/updates/module.go new file mode 100644 index 000000000..8e62bc67b --- /dev/null +++ b/service/updates/module.go @@ -0,0 +1,60 @@ +package updates + +import ( + "errors" + "sync/atomic" + + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/service/mgr" +) + +// Updates provides access to released artifacts. +type Updates struct { + mgr *mgr.Manager + + instance instance + shutdownFunc func(exitCode int) + + EventResourcesUpdated *mgr.EventMgr[struct{}] + EventVersionsUpdated *mgr.EventMgr[struct{}] + + States *mgr.StateMgr +} + +// Start starts the module. +func (u *Updates) Start(m *mgr.Manager) error { + u.mgr = m + u.EventResourcesUpdated = mgr.NewEventMgr[struct{}](ResourceUpdateEvent, u.mgr) + u.EventVersionsUpdated = mgr.NewEventMgr[struct{}](VersionUpdateEvent, u.mgr) + u.States = mgr.NewStateMgr(u.mgr) + + return start() +} + +// Stop stops the module. +func (u *Updates) Stop(_ *mgr.Manager) error { + return stop() +} + +var ( + module *Updates + shimLoaded atomic.Bool +) + +// New returns a new UI module. +func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { + if shimLoaded.CompareAndSwap(false, true) { + module = &Updates{ + instance: instance, + shutdownFunc: shutdownFunc, + } + return module, nil + } + return nil, errors.New("only one instance allowed") +} + +type instance interface { + API() *api.API + Config() *config.Config +} diff --git a/service/updates/notify.go b/service/updates/notify.go index 01dde5234..662a1b82c 100644 --- a/service/updates/notify.go +++ b/service/updates/notify.go @@ -23,7 +23,7 @@ var updateFailedCnt = new(atomic.Int32) func notifyUpdateSuccess(force bool) { updateFailedCnt.Store(0) - module.Resolve(updateFailed) + module.States.Clear() updateState := registry.GetState().Updates flavor := updateSuccess diff --git a/service/updates/restart.go b/service/updates/restart.go index b08fdf30c..f219b1ef1 100644 --- a/service/updates/restart.go +++ b/service/updates/restart.go @@ -1,7 +1,6 @@ package updates import ( - "context" "os/exec" "runtime" "sync" @@ -9,8 +8,9 @@ import ( "github.com/tevino/abool" + "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) const ( @@ -94,7 +94,7 @@ func RestartNow() { restartTask.StartASAP() } -func automaticRestart(_ context.Context, _ *modules.Task) error { +func automaticRestart(w *mgr.WorkerCtx) error { // Check if the restart is still scheduled. 
if restartPending.IsNotSet() { return nil @@ -116,11 +116,10 @@ func automaticRestart(_ context.Context, _ *modules.Task) error { // Set restart exit code. if !rebooting { - modules.SetExitStatusCode(RestartExitCode) + module.shutdownFunc(RestartExitCode) + } else { + module.shutdownFunc(0) } - - // Do not use a worker, as this would block itself here. - go modules.Shutdown() //nolint:errcheck } return nil diff --git a/service/updates/upgrader.go b/service/updates/upgrader.go index 1dc64f16d..889ccf4c9 100644 --- a/service/updates/upgrader.go +++ b/service/updates/upgrader.go @@ -21,6 +21,7 @@ import ( "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/base/utils/renameio" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates/helper" ) @@ -41,23 +42,19 @@ var ( ) func initUpgrader() error { - return module.RegisterEventHook( - ModuleName, - ResourceUpdateEvent, - "run upgrades", - upgrader, - ) + module.EventResourcesUpdated.AddCallback("run upgrades", upgrader) + return nil } -func upgrader(_ context.Context, _ interface{}) error { +func upgrader(m *mgr.WorkerCtx, _ struct{}) (cancel bool, err error) { // Lock runs, but discard additional runs. if !upgraderActive.SetToIf(false, true) { - return nil + return false, nil } defer upgraderActive.SetTo(false) // Upgrade portmaster-start. - err := upgradePortmasterStart() + err = upgradePortmasterStart() if err != nil { log.Warningf("updates: failed to upgrade portmaster-start: %s", err) } @@ -86,7 +83,7 @@ func upgrader(_ context.Context, _ interface{}) error { } } - return nil + return false, nil } func upgradeCoreNotify() error { From 2831227c2eac11fc0f803785f57cb435f3b500db Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Fri, 7 Jun 2024 17:28:44 +0300 Subject: [PATCH 06/56] [WIP] Switch more modules --- base/info/module/flags.go | 32 ++++++++-- base/metrics/api.go | 5 +- base/metrics/module.go | 48 +++++++++++---- service/firewall/module.go | 99 ++++++++++++++++-------------- service/firewall/packet_handler.go | 15 ++--- service/instance.go | 46 +++++++++++++- service/mgr/events.go | 2 +- service/profile/active.go | 4 +- service/profile/database.go | 8 +-- service/profile/module.go | 61 +++++++++++++++--- 10 files changed, 232 insertions(+), 88 deletions(-) diff --git a/base/info/module/flags.go b/base/info/module/flags.go index f1c8af230..2df0bbf2a 100644 --- a/base/info/module/flags.go +++ b/base/info/module/flags.go @@ -1,22 +1,27 @@ package module import ( + "errors" "flag" "fmt" + "sync/atomic" + "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/info" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) +type Info struct { + instance instance +} + var showVersion bool func init() { - modules.Register("info", prep, nil, nil) - flag.BoolVar(&showVersion, "version", false, "show version and exit") } -func prep() error { +func (i *Info) Start(m *mgr.Manager) error { err := info.CheckVersion() if err != nil { return err @@ -28,6 +33,10 @@ func prep() error { return nil } +func (i *Info) Stop(m *mgr.Manager) error { + return nil +} + // printVersion prints the version, if requested, and returns if it did so. 
func printVersion() (printed bool) { if showVersion { @@ -36,3 +45,18 @@ func printVersion() (printed bool) { } return false } + +var shimLoaded atomic.Bool + +func New(instance instance) (*Info, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + module := &Info{ + instance: instance, + } + + return module, nil +} + +type instance interface{} diff --git a/base/metrics/api.go b/base/metrics/api.go index fee2b8f3a..7cd3d8207 100644 --- a/base/metrics/api.go +++ b/base/metrics/api.go @@ -11,6 +11,7 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" ) func registerAPI() error { @@ -139,7 +140,7 @@ func writeMetricsTo(ctx context.Context, url string) error { ) } -func metricsWriter(ctx context.Context) error { +func metricsWriter(ctx mgr.WorkerCtx) error { pushURL := pushOption() ticker := module.NewSleepyTicker(1*time.Minute, 0) defer ticker.Stop() @@ -149,7 +150,7 @@ func metricsWriter(ctx context.Context) error { case <-ctx.Done(): return nil case <-ticker.Wait(): - err := writeMetricsTo(ctx, pushURL) + err := writeMetricsTo(ctx.Ctx(), pushURL) if err != nil { return err } diff --git a/base/metrics/module.go b/base/metrics/module.go index 96ed9563a..1d4eb2ce8 100644 --- a/base/metrics/module.go +++ b/base/metrics/module.go @@ -5,12 +5,32 @@ import ( "fmt" "sort" "sync" + "sync/atomic" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portbase/modules" + "github.com/safing/portmaster/service/mgr" ) +type Metrics struct { + mgr *mgr.Manager + instance instance +} + +func (met *Metrics) Start(m *mgr.Manager) error { + met.mgr = m + if err := prepConfig(); err != nil { + return err + } + return start() +} + +func (met *Metrics) Stop(m *mgr.Manager) error { + return stop() +} + var ( - module *modules.Module + module *Metrics + shimLoaded atomic.Bool registry []Metric registryLock sync.RWMutex @@ -34,14 +54,6 @@ var ( ErrInvalidOptions = errors.New("invalid options") ) -func init() { - module = modules.Register("metrics", prep, start, stop, "config", "database", "api") -} - -func prep() error { - return prepConfig() -} - func start() error { // Add metric instance name as global variable if set. 
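Workers follow the same conversion throughout: the function now receives a *mgr.WorkerCtx, uses w.Done() for shutdown and w.Ctx() where a plain context.Context is still needed. A compact sketch of such a service-worker loop; the startPusher name and the push parameter are placeholders:

package example

import (
	"context"
	"time"

	"github.com/safing/portmaster/base/log"
	"github.com/safing/portmaster/service/mgr"
)

// startPusher runs a periodic push as a managed worker.
func startPusher(m *mgr.Manager, push func(ctx context.Context) error) {
	m.Go("metric pusher", func(w *mgr.WorkerCtx) error {
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()

		for {
			select {
			case <-w.Done():
				// Manager is shutting down.
				return nil
			case <-ticker.C:
				if err := push(w.Ctx()); err != nil {
					log.Warningf("example: push failed: %s", err)
				}
			}
		}
	})
}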
if instanceOption() != "" { @@ -71,7 +83,7 @@ func start() error { } if pushOption() != "" { - module.StartServiceWorker("metric pusher", 0, metricsWriter) + module.mgr.Do("metric pusher", metricsWriter) } return nil @@ -169,3 +181,17 @@ type byLabeledID []Metric func (r byLabeledID) Len() int { return len(r) } func (r byLabeledID) Less(i, j int) bool { return r[i].LabeledID() < r[j].LabeledID() } func (r byLabeledID) Swap(i, j int) { r[i], r[j] = r[j], r[i] } + +func New(instance instance) (*Metrics, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Metrics{ + instance: instance, + } + + return module, nil +} + +type instance interface{} diff --git a/service/firewall/module.go b/service/firewall/module.go index 9430b48af..71c9a8307 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -6,12 +6,12 @@ import ( "fmt" "path/filepath" "strings" + "sync/atomic" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" - "github.com/safing/portmaster/base/modules/subsystems" _ "github.com/safing/portmaster/service/core" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/spn/access" @@ -29,66 +29,55 @@ func (ss *stringSliceFlag) Set(value string) error { return nil } -var ( - module *modules.Module - allowedClients stringSliceFlag -) +// module *modules.Module +var allowedClients stringSliceFlag -func init() { - module = modules.Register("filter", prep, start, stop, "core", "interception", "intel", "netquery") - subsystems.Register( - "filter", - "Privacy Filter", - "DNS and Network Filter", - module, - "config:filter/", - nil, - ) +type Filter struct { + mgr *mgr.Manager + + instance instance +} +func init() { flag.Var(&allowedClients, "allowed-clients", "A list of binaries that are allowed to connect to the Portmaster API") } +func (f *Filter) Start(mgr *mgr.Manager) error { + f.mgr = mgr + + if err := prep(); err != nil { + return err + } + + return start() +} + +func (f *Filter) Stop(mgr *mgr.Manager) error { + return stop() +} + func prep() error { network.SetDefaultFirewallHandler(defaultFirewallHandler) // Reset connections every time configuration changes // this will be triggered on spn enable/disable - err := module.RegisterEventHook( - "config", - config.ChangeEvent, - "reset connection verdicts after global config change", - func(ctx context.Context, _ interface{}) error { - resetAllConnectionVerdicts() - return nil - }, - ) - if err != nil { - log.Errorf("filter: failed to register event hook: %s", err) - } + module.instance.Config().EventConfigChange.AddCallback("reset connection verdicts after global config change", func(w *mgr.WorkerCtx, _ struct{}) (bool, error) { + resetAllConnectionVerdicts() + return false, nil + }) - // Reset connections every time profile changes - err = module.RegisterEventHook( - "profiles", - profile.ConfigChangeEvent, - "reset connection verdicts after profile config change", - func(ctx context.Context, eventData interface{}) error { + module.instance.Profile().EventConfigChange.AddCallback("reset connection verdicts after profile config change", + func(m *mgr.WorkerCtx, profileID string) (bool, error) { // Expected event data: scoped profile ID. 
- profileID, ok := eventData.(string) - if !ok { - return fmt.Errorf("event data is not a string: %v", eventData) - } profileSource, profileID, ok := strings.Cut(profileID, "/") if !ok { - return fmt.Errorf("event data does not seem to be a scoped profile ID: %v", eventData) + return false, fmt.Errorf("event data does not seem to be a scoped profile ID: %v", profileID) } resetProfileConnectionVerdict(profileSource, profileID) - return nil + return false, nil }, ) - if err != nil { - log.Errorf("filter: failed to register event hook: %s", err) - } // Reset connections when spn is connected // connect and disconnecting is triggered on config change event but connecting takеs more time @@ -149,12 +138,12 @@ func start() error { getConfig() startAPIAuth() - module.StartServiceWorker("packet handler", 0, packetHandler) - module.StartServiceWorker("bandwidth update handler", 0, bandwidthUpdateHandler) + module.mgr.Go("packet handler", packetHandler) + module.mgr.Go("bandwidth update handler", bandwidthUpdateHandler) // Start stat logger if logging is set to trace. if log.GetLogLevel() == log.TraceLevel { - module.StartServiceWorker("stat logger", 0, statLogger) + module.mgr.Go("stat logger", statLogger) } return nil @@ -163,3 +152,21 @@ func start() error { func stop() error { return nil } + +var ( + module *Filter + shimLoaded atomic.Bool +) + +func New(instance instance) (*Filter, error) { + module = &Filter{ + instance: instance, + } + + return module, nil +} + +type instance interface { + Config() *config.Config + Profile() *profile.ProfileModule +} diff --git a/service/firewall/packet_handler.go b/service/firewall/packet_handler.go index faf3dceec..9e48d2468 100644 --- a/service/firewall/packet_handler.go +++ b/service/firewall/packet_handler.go @@ -17,6 +17,7 @@ import ( _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/firewall/inspection" "github.com/safing/portmaster/service/firewall/interception" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/netquery" "github.com/safing/portmaster/service/network" @@ -720,10 +721,10 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V // return // } -func packetHandler(ctx context.Context) error { +func packetHandler(w *mgr.WorkerCtx) error { for { select { - case <-ctx.Done(): + case <-w.Done(): return nil case pkt := <-interception.Packets: if pkt != nil { @@ -735,16 +736,16 @@ func packetHandler(ctx context.Context) error { } } -func bandwidthUpdateHandler(ctx context.Context) error { +func bandwidthUpdateHandler(w *mgr.WorkerCtx) error { for { select { - case <-ctx.Done(): + case <-w.Done(): return nil case bwUpdate := <-interception.BandwidthUpdates: if bwUpdate != nil { // DEBUG: // log.Debugf("filter: bandwidth update: %s", bwUpdate) - updateBandwidth(ctx, bwUpdate) + updateBandwidth(w.Ctx(), bwUpdate) } else { return errors.New("received nil bandwidth update from interception") } @@ -808,10 +809,10 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) { } } -func statLogger(ctx context.Context) error { +func statLogger(w *mgr.WorkerCtx) error { for { select { - case <-ctx.Done(): + case <-w.Done(): return nil case <-time.After(10 * time.Second): log.Tracef( diff --git a/service/instance.go b/service/instance.go index bda1ffcb8..8f34dad2d 100644 --- a/service/instance.go +++ b/service/instance.go @@ -5,12 +5,15 @@ import ( "github.com/safing/portmaster/base/api" 
"github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/service/firewall" "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/service/ui" "github.com/safing/portmaster/service/updates" ) -// Instance is an instance of a mycoria router. +// Instance is an instance of a portmaste service. type Instance struct { *mgr.Group @@ -20,9 +23,13 @@ type Instance struct { ui *ui.UI updates *updates.Updates config *config.Config + profile *profile.ProfileModule + metrics *metrics.Metrics + + filter *firewall.Filter } -// New returns a new mycoria router instance. +// New returns a new portmaster service instance. func New(version string, svcCfg *ServiceConfig) (*Instance, error) { // Create instance to pass it to modules. instance := &Instance{ @@ -30,6 +37,8 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { } var err error + + // Base modules instance.config, err = config.New(instance) if err != nil { return nil, fmt.Errorf("create config module: %w", err) @@ -38,6 +47,12 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create api module: %w", err) } + instance.metrics, err = metrics.New(instance) + if err != nil { + return nil, fmt.Errorf("create metrics module: %w", err) + } + + // Service modules instance.updates, err = updates.New(instance, svcCfg.ShutdownFunc) if err != nil { return nil, fmt.Errorf("create updates module: %w", err) @@ -46,13 +61,25 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create ui module: %w", err) } + instance.profile, err = profile.NewModule(instance) + if err != nil { + return nil, fmt.Errorf("create profile module: %w", err) + } + instance.filter, err = firewall.New(instance) + if err != nil { + return nil, fmt.Errorf("create filter module: %w", err) + } // Add all modules to instance group. instance.Group = mgr.NewGroup( instance.config, instance.api, + instance.metrics, + instance.updates, instance.ui, + instance.profile, + instance.filter, ) return instance, nil @@ -68,6 +95,11 @@ func (i *Instance) API() *api.API { return i.api } +// Metrics returns the metrics module. +func (i *Instance) Metrics() *metrics.Metrics { + return i.metrics +} + // UI returns the ui module. func (i *Instance) UI() *ui.UI { return i.ui @@ -77,3 +109,13 @@ func (i *Instance) UI() *ui.UI { func (i *Instance) Config() *config.Config { return i.config } + +// Profile returns the profile module. +func (i *Instance) Profile() *profile.ProfileModule { + return i.profile +} + +// Profile returns the profile module. +func (i *Instance) Firewall() *firewall.Filter { + return i.filter +} diff --git a/service/mgr/events.go b/service/mgr/events.go index da7c23e2d..03436dc62 100644 --- a/service/mgr/events.go +++ b/service/mgr/events.go @@ -85,7 +85,7 @@ func (em *EventMgr[T]) Submit(event T) { // Send to subscriptions. for _, sub := range em.subs { - // Check if subcription was canceled. + // Check if subscription was canceled. 
if sub.canceled.Load() { anyCanceled = true continue diff --git a/service/profile/active.go b/service/profile/active.go index 6de5041da..2ac053e7e 100644 --- a/service/profile/active.go +++ b/service/profile/active.go @@ -4,6 +4,8 @@ import ( "context" "sync" "time" + + "github.com/safing/portmaster/service/mgr" ) const ( @@ -53,7 +55,7 @@ func addActiveProfile(profile *Profile) { activeProfiles[profile.ScopedID()] = profile } -func cleanActiveProfiles(ctx context.Context) error { +func cleanActiveProfiles(ctx *mgr.WorkerCtx) error { for { select { case <-time.After(activeProfileCleanerTickDuration): diff --git a/service/profile/database.go b/service/profile/database.go index ce0633cef..634026748 100644 --- a/service/profile/database.go +++ b/service/profile/database.go @@ -1,7 +1,6 @@ package profile import ( - "context" "errors" "strings" @@ -10,6 +9,7 @@ import ( "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" ) // Database paths: @@ -40,7 +40,7 @@ func registerValidationDBHook() (err error) { } func startProfileUpdateChecker() error { - module.StartServiceWorker("update active profiles", 0, func(ctx context.Context) (err error) { + module.mgr.Go("update active profiles", func(ctx *mgr.WorkerCtx) (err error) { profilesSub, err := profileDB.Subscribe(query.New(ProfilesDBPath)) if err != nil { return err @@ -85,7 +85,7 @@ func startProfileUpdateChecker() error { activeProfile.outdated.Set() meta.MarkDeleted(scopedID) - module.TriggerEvent(DeletedEvent, scopedID) + module.EventDelete.Submit(scopedID) continue } @@ -94,7 +94,7 @@ func startProfileUpdateChecker() error { receivedProfile, err := EnsureProfile(r) if err != nil || !receivedProfile.savedInternally { activeProfile.outdated.Set() - module.TriggerEvent(ConfigChangeEvent, scopedID) + module.EventConfigChange.Submit(scopedID) } case <-ctx.Done(): return nil diff --git a/service/profile/module.go b/service/profile/module.go index aaca99d91..03532fad2 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "os" + "sync/atomic" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/migration" @@ -11,13 +12,14 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/modules" _ "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/profile/binmeta" "github.com/safing/portmaster/service/updates" ) var ( - migrations = migration.New("core:migrations/profile") - module *modules.Module + migrations = migration.New("core:migrations/profile") + // module *modules.Module updatesPath string ) @@ -28,11 +30,31 @@ const ( MigratedEvent = "profile migrated" ) -func init() { - module = modules.Register("profiles", prep, start, stop, "base", "updates") - module.RegisterEvent(ConfigChangeEvent, true) - module.RegisterEvent(DeletedEvent, true) - module.RegisterEvent(MigratedEvent, true) +type ProfileModule struct { + mgr *mgr.Manager + instance instance + + EventConfigChange *mgr.EventMgr[string] + EventDelete *mgr.EventMgr[string] + EventMigrated *mgr.EventMgr[string] +} + +func (pm *ProfileModule) Start(m *mgr.Manager) error { + pm.mgr = m + + pm.EventConfigChange = mgr.NewEventMgr[string](ConfigChangeEvent, m) + pm.EventDelete = mgr.NewEventMgr[string](DeletedEvent, m) + pm.EventMigrated = 
mgr.NewEventMgr[string](MigratedEvent, m) + + if err := prep(); err != nil { + return err + } + + return start() +} + +func (pm *ProfileModule) Stop(m *mgr.Manager) error { + return stop() } func prep() error { @@ -72,7 +94,7 @@ func start() error { } meta.check() - if err := migrations.Migrate(module.Ctx); err != nil { + if err := migrations.Migrate(module.mgr.Ctx()); err != nil { log.Errorf("profile: migrations failed: %s", err) } @@ -91,9 +113,9 @@ func start() error { return err } - module.StartServiceWorker("clean active profiles", 0, cleanActiveProfiles) + module.mgr.Go("clean active profiles", cleanActiveProfiles) - err = updateGlobalConfigProfile(module.Ctx, nil) + err = updateGlobalConfigProfile(module.mgr.Ctx(), nil) if err != nil { log.Warningf("profile: error during loading global profile from configuration: %s", err) } @@ -108,3 +130,22 @@ func start() error { func stop() error { return meta.Save() } + +var ( + module *ProfileModule + shimLoaded atomic.Bool +) + +func NewModule(instance instance) (*ProfileModule, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &ProfileModule{ + instance: instance, + } + + return module, nil +} + +type instance interface{} From b75ef773ba2abc1420aa584680cd96fa385b7b19 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Tue, 11 Jun 2024 18:12:35 +0300 Subject: [PATCH 07/56] [WIP] swtich more modules --- base/api/authentication.go | 8 +- base/api/main.go | 17 +- base/api/module.go | 13 +- base/config/module.go | 5 +- base/database/dbmodule/db.go | 42 ++++- base/database/dbmodule/maintenance.go | 21 ++- base/metrics/api.go | 2 +- base/notifications/cleaner.go | 9 +- base/notifications/module-mirror.go | 5 +- base/notifications/module.go | 59 +++++-- base/rng/entropy.go | 6 +- base/rng/fullfeed.go | 5 +- base/rng/osfeeder.go | 5 +- base/rng/rng.go | 50 ++++-- base/rng/tickfeeder.go | 5 +- base/runtime/module.go | 31 +++- base/template/module.go | 216 ++++++++++++------------ service/firewall/interception/module.go | 13 +- service/instance.go | 136 ++++++++++++++- service/netenv/main.go | 75 +++++--- service/netenv/network-change.go | 6 +- service/netenv/online-status.go | 7 +- service/status/module.go | 52 ++++-- spn/access/client.go | 6 +- spn/access/module.go | 58 +++++-- spn/cabin/module.go | 40 ++++- spn/captain/client.go | 13 +- spn/captain/module.go | 76 ++++++--- spn/captain/navigation.go | 3 +- spn/crew/connect.go | 7 +- spn/crew/module.go | 37 +++- spn/crew/op_connect.go | 23 +-- spn/docks/crane.go | 9 +- spn/docks/module.go | 42 ++++- spn/navigator/update.go | 5 +- spn/terminal/control_flow.go | 3 +- 36 files changed, 783 insertions(+), 327 deletions(-) diff --git a/base/api/authentication.go b/base/api/authentication.go index a43512d1c..cdd724702 100644 --- a/base/api/authentication.go +++ b/base/api/authentication.go @@ -15,8 +15,8 @@ import ( "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" ) const ( @@ -351,7 +351,7 @@ func checkAPIKey(r *http.Request) *AuthToken { return token } -func updateAPIKeys(_ context.Context, _ interface{}) error { +func updateAPIKeys(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { apiKeysLock.Lock() defer apiKeysLock.Unlock() @@ -433,7 +433,7 @@ func updateAPIKeys(_ context.Context, _ interface{}) error { } if hasExpiredKeys { - module.StartLowPriorityMicroTask("api 
key cleanup", 0, func(ctx context.Context) error { + module.mgr.Go("api key cleanup", func(ctx *mgr.WorkerCtx) error { if err := config.SetConfigOption(CfgAPIKeys, validAPIKeys); err != nil { log.Errorf("api: failed to remove expired API keys: %s", err) } else { @@ -444,7 +444,7 @@ func updateAPIKeys(_ context.Context, _ interface{}) error { }) } - return nil + return false, nil } func checkSessionCookie(r *http.Request) *AuthToken { diff --git a/base/api/main.go b/base/api/main.go index 687130ff0..7693e52d1 100644 --- a/base/api/main.go +++ b/base/api/main.go @@ -6,15 +6,9 @@ import ( "flag" "os" "time" - - "github.com/safing/portmaster/base/modules" ) -var ( - module *modules.Module - - exportEndpoints bool -) +var exportEndpoints bool // API Errors. var ( @@ -23,7 +17,7 @@ var ( ) func init() { - module = modules.Register("api", prep, start, stop, "database", "config") + // module = modules.Register("api", prep, start, stop, "database", "config") flag.BoolVar(&exportEndpoints, "export-api-endpoints", false, "export api endpoint registry and exit") } @@ -59,11 +53,8 @@ func prep() error { func start() error { startServer() - _ = updateAPIKeys(module.Ctx, nil) - err := module.RegisterEventHook("config", "config change", "update API keys", updateAPIKeys) - if err != nil { - return err - } + _ = updateAPIKeys(module.mgr.Ctx(), nil) + module.instance.Config().EventConfigChange.AddCallback("update API keys", updateAPIKeys) // start api auth token cleaner if authFnSet.IsSet() { diff --git a/base/api/module.go b/base/api/module.go index 7aa71a861..740e8d4d4 100644 --- a/base/api/module.go +++ b/base/api/module.go @@ -4,16 +4,19 @@ import ( "errors" "sync/atomic" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/service/mgr" ) // API is the HTTP/Websockets API module. type API struct { + mgr *mgr.Manager instance instance } // Start starts the module. -func (api *API) Start(_ *mgr.Manager) error { +func (api *API) Start(m *mgr.Manager) error { + api.mgr = m return start() } @@ -24,6 +27,7 @@ func (api *API) Stop(_ *mgr.Manager) error { var ( shimLoaded atomic.Bool + module *API ) // New returns a new UI module. @@ -36,10 +40,13 @@ func New(instance instance) (*API, error) { return nil, err } - return &API{ + module = &API{ instance: instance, - }, nil + } + + return module, nil } type instance interface { + Config() *config.Config } diff --git a/base/config/module.go b/base/config/module.go index f7ed07b84..2f3031d2b 100644 --- a/base/config/module.go +++ b/base/config/module.go @@ -33,7 +33,7 @@ var ( shimLoaded atomic.Bool ) -// New returns a new UI module. +// New returns a new Config module. 
func New(instance instance) (*Config, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") @@ -49,5 +49,4 @@ func New(instance instance) (*Config, error) { return module, nil } -type instance interface { -} +type instance interface{} diff --git a/base/database/dbmodule/db.go b/base/database/dbmodule/db.go index 23eecf1f8..b92bf83da 100644 --- a/base/database/dbmodule/db.go +++ b/base/database/dbmodule/db.go @@ -2,23 +2,26 @@ package dbmodule import ( "errors" + "sync/atomic" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/dataroot" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/utils" + "github.com/safing/portmaster/service/mgr" ) -var ( - databaseStructureRoot *utils.DirStructure - - module *modules.Module -) +type DBModule struct { + mgr *mgr.Manager + instance instance +} -func init() { - module = modules.Register("database", prep, start, stop) +func (dbm *DBModule) Start(m *mgr.Manager) error { + module.mgr = m + return start() } +var databaseStructureRoot *utils.DirStructure + // SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure. func SetDatabaseLocation(dirStructureRoot *utils.DirStructure) { if databaseStructureRoot == nil { @@ -48,3 +51,26 @@ func start() error { func stop() error { return database.Shutdown() } + +var ( + module *DBModule + shimLoaded atomic.Bool +) + +func New(instance instance) (*DBModule, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + if err := prep(); err != nil { + return nil, err + } + + module = &DBModule{ + instance: instance, + } + + return module, nil +} + +type instance interface{} diff --git a/base/database/dbmodule/maintenance.go b/base/database/dbmodule/maintenance.go index 3326ebaf4..22fddac33 100644 --- a/base/database/dbmodule/maintenance.go +++ b/base/database/dbmodule/maintenance.go @@ -1,31 +1,30 @@ package dbmodule import ( - "context" "time" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) func startMaintenanceTasks() { - module.NewTask("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute) - module.NewTask("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) - module.NewTask("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) + module.mgr.Go("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute) + module.mgr.Go("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) + module.mgr.Go("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) } -func maintainBasic(ctx context.Context, task *modules.Task) error { +func maintainBasic(ctx mgr.WorkerCtx) error { log.Infof("database: running Maintain") - return database.Maintain(ctx) + return database.Maintain(ctx.Ctx()) } -func maintainThorough(ctx context.Context, task *modules.Task) error { +func maintainThorough(ctx mgr.WorkerCtx) error { log.Infof("database: running MaintainThorough") - return database.MaintainThorough(ctx) + return database.MaintainThorough(ctx.Ctx()) } -func maintainRecords(ctx context.Context, task *modules.Task) error { +func maintainRecords(ctx mgr.WorkerCtx) error { log.Infof("database: running 
MaintainRecordStates") - return database.MaintainRecordStates(ctx) + return database.MaintainRecordStates(ctx.Ctx()) } diff --git a/base/metrics/api.go b/base/metrics/api.go index 7cd3d8207..10a7d0d85 100644 --- a/base/metrics/api.go +++ b/base/metrics/api.go @@ -140,7 +140,7 @@ func writeMetricsTo(ctx context.Context, url string) error { ) } -func metricsWriter(ctx mgr.WorkerCtx) error { +func metricsWriter(ctx *mgr.WorkerCtx) error { pushURL := pushOption() ticker := module.NewSleepyTicker(1*time.Minute, 0) defer ticker.Stop() diff --git a/base/notifications/cleaner.go b/base/notifications/cleaner.go index 62982616f..b5395ee66 100644 --- a/base/notifications/cleaner.go +++ b/base/notifications/cleaner.go @@ -1,19 +1,20 @@ package notifications import ( - "context" "time" + + "github.com/safing/portmaster/service/mgr" ) -func cleaner(ctx context.Context) error { //nolint:unparam // Conforms to worker interface - ticker := module.NewSleepyTicker(1*time.Second, 0) +func cleaner(ctx *mgr.WorkerCtx) error { //nolint:unparam // Conforms to worker interface + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { case <-ctx.Done(): return nil - case <-ticker.Wait(): + case <-ticker.C: deleteExpiredNotifs() } } diff --git a/base/notifications/module-mirror.go b/base/notifications/module-mirror.go index 96173be4d..43df67d90 100644 --- a/base/notifications/module-mirror.go +++ b/base/notifications/module-mirror.go @@ -1,13 +1,14 @@ package notifications import ( + "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) // AttachToModule attaches the notification to a module and changes to the // notification will be reflected on the module failure status. -func (n *Notification) AttachToModule(m *modules.Module) { +func (n *Notification) AttachToModule(m mgr.Module) { if m == nil { log.Warningf("notifications: invalid usage: cannot attach %s to nil module", n.EventID) return diff --git a/base/notifications/module.go b/base/notifications/module.go index 839c522cd..455376944 100644 --- a/base/notifications/module.go +++ b/base/notifications/module.go @@ -1,17 +1,34 @@ package notifications import ( + "errors" "fmt" - "time" + "sync/atomic" "github.com/safing/portmaster/base/config" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) -var module *modules.Module +type Notifications struct { + mgr *mgr.Manager + instance instance -func init() { - module = modules.Register("notifications", prep, start, nil, "database", "config", "base") + States *mgr.StateMgr +} + +func (n *Notifications) Start(m *mgr.Manager) error { + n.mgr = m + n.States = mgr.NewStateMgr(n.mgr) + + if err := prep(); err != nil { + return err + } + + return start() +} + +func (n *Notifications) Stop(m *mgr.Manager) error { + return nil } func prep() error { @@ -26,7 +43,7 @@ func start() error { showConfigLoadingErrors() - go module.StartServiceWorker("cleaner", 1*time.Second, cleaner) + module.mgr.Go("cleaner", cleaner) return nil } @@ -37,11 +54,12 @@ func showConfigLoadingErrors() { } // Trigger a module error for more awareness. - module.Error( - "config:validation-errors-on-load", - "Invalid Settings", - "Some current settings are invalid. Please update them and restart the Portmaster.", - ) + module.States.Add(mgr.State{ + ID: "config:validation-errors-on-load", + Name: "Invalid Settings", + Message: "Some current settings are invalid. 
Please update them and restart the Portmaster.", + Type: mgr.StateTypeError, + }) // Send one notification per invalid setting. for _, validationError := range config.GetLoadedConfigValidationErrors() { @@ -64,3 +82,22 @@ Please update the setting and restart the Portmaster, until then the default val ) } } + +var ( + module *Notifications + shimLoaded atomic.Bool +) + +func New(instance instance) (*Notifications, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Notifications{ + instance: instance, + } + + return module, nil +} + +type instance interface{} diff --git a/base/rng/entropy.go b/base/rng/entropy.go index b81e2bfde..7c78dc6a6 100644 --- a/base/rng/entropy.go +++ b/base/rng/entropy.go @@ -1,12 +1,12 @@ package rng import ( - "context" "encoding/binary" "github.com/tevino/abool" "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/service/mgr" ) const ( @@ -35,7 +35,7 @@ func NewFeeder() *Feeder { needsEntropy: abool.NewBool(true), buffer: container.New(), } - module.StartServiceWorker("feeder", 0, newFeeder.run) + module.mgr.Go("feeder", newFeeder.run) return newFeeder } @@ -88,7 +88,7 @@ func (f *Feeder) CloseFeeder() { close(f.input) } -func (f *Feeder) run(ctx context.Context) error { +func (f *Feeder) run(ctx *mgr.WorkerCtx) error { defer f.needsEntropy.UnSet() for { diff --git a/base/rng/fullfeed.go b/base/rng/fullfeed.go index e055f3e1d..7e3c9bfca 100644 --- a/base/rng/fullfeed.go +++ b/base/rng/fullfeed.go @@ -1,8 +1,9 @@ package rng import ( - "context" "time" + + "github.com/safing/portmaster/service/mgr" ) func getFullFeedDuration() time.Duration { @@ -17,7 +18,7 @@ func getFullFeedDuration() time.Duration { return time.Duration(secsUntilFullFeed) * time.Second } -func fullFeeder(ctx context.Context) error { +func fullFeeder(ctx *mgr.WorkerCtx) error { fullFeedDuration := getFullFeedDuration() for { diff --git a/base/rng/osfeeder.go b/base/rng/osfeeder.go index 36aa8e4d0..253eaf5f6 100644 --- a/base/rng/osfeeder.go +++ b/base/rng/osfeeder.go @@ -1,12 +1,13 @@ package rng import ( - "context" "crypto/rand" "fmt" + + "github.com/safing/portmaster/service/mgr" ) -func osFeeder(ctx context.Context) error { +func osFeeder(ctx *mgr.WorkerCtx) error { entropyBytes := minFeedEntropy / 8 feeder := NewFeeder() defer feeder.CloseFeeder() diff --git a/base/rng/rng.go b/base/rng/rng.go index fa9bf5ca3..8c73ebe8c 100644 --- a/base/rng/rng.go +++ b/base/rng/rng.go @@ -1,20 +1,25 @@ package rng import ( - "context" "crypto/aes" "crypto/cipher" "crypto/rand" "errors" "fmt" "sync" + "sync/atomic" "github.com/aead/serpent" + "github.com/safing/portmaster/service/mgr" "github.com/seehuhn/fortuna" - - "github.com/safing/portmaster/base/modules" ) +type Rng struct { + mgr *mgr.Manager + + instance instance +} + var ( rng *fortuna.Generator rngLock sync.Mutex @@ -23,13 +28,8 @@ var ( rngCipher = "aes" // Possible values: "aes", "serpent". 
- module *modules.Module ) -func init() { - module = modules.Register("rng", nil, start, nil) -} - func newCipher(key []byte) (cipher.Block, error) { switch rngCipher { case "aes": @@ -41,7 +41,8 @@ func newCipher(key []byte) (cipher.Block, error) { } } -func start() error { +func (r *Rng) Start(m *mgr.Manager) error { + r.mgr = m rngLock.Lock() defer rngLock.Unlock() @@ -51,7 +52,7 @@ func start() error { } // add another (async) OS rng seed - module.StartWorker("initial rng feed", func(_ context.Context) error { + m.Go("initial rng feed", func(_ *mgr.WorkerCtx) error { // get entropy from OS osEntropy := make([]byte, minFeedEntropy/8) _, err := rand.Read(osEntropy) @@ -69,13 +70,36 @@ func start() error { rngReady = true // random source: OS - module.StartServiceWorker("os rng feeder", 0, osFeeder) + m.Go("os rng feeder", osFeeder) // random source: goroutine ticks - module.StartServiceWorker("tick rng feeder", 0, tickFeeder) + m.Go("tick rng feeder", tickFeeder) // full feeder - module.StartServiceWorker("full feeder", 0, fullFeeder) + m.Go("full feeder", fullFeeder) return nil } + +func (r *Rng) Stop(m *mgr.Manager) error { + return nil +} + +var ( + module *Rng + shimLoaded atomic.Bool +) + +func New(instance instance) (*Rng, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Rng{ + instance: instance, + } + + return module, nil +} + +type instance interface{} diff --git a/base/rng/tickfeeder.go b/base/rng/tickfeeder.go index 6dbe69ac0..42fd5e113 100644 --- a/base/rng/tickfeeder.go +++ b/base/rng/tickfeeder.go @@ -1,9 +1,10 @@ package rng import ( - "context" "encoding/binary" "time" + + "github.com/safing/portmaster/service/mgr" ) func getTickFeederTickDuration() time.Duration { @@ -31,7 +32,7 @@ func getTickFeederTickDuration() time.Duration { // tickFeeder is a really simple entropy feeder that adds the least significant bit of the current nanosecond unixtime to its pool every time it 'ticks'. // The more work the program does, the better the quality, as the internal schedular cannot immediately run the goroutine when it's ready. -func tickFeeder(ctx context.Context) error { +func tickFeeder(ctx *mgr.WorkerCtx) error { var value int64 var pushes int feeder := NewFeeder() diff --git a/base/runtime/module.go b/base/runtime/module.go index 169abc94e..6ef8dec87 100644 --- a/base/runtime/module.go +++ b/base/runtime/module.go @@ -1,21 +1,23 @@ package runtime import ( + "errors" "fmt" + "sync/atomic" "github.com/safing/portmaster/base/database" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) // DefaultRegistry is the default registry // that is used by the module-level API. var DefaultRegistry = NewRegistry() -func init() { - modules.Register("runtime", nil, startModule, nil, "database") +type Runtime struct { + instance instance } -func startModule() error { +func (r *Runtime) Start(m *mgr.Manager) error { _, err := database.Register(&database.Database{ Name: "runtime", Description: "Runtime database", @@ -37,8 +39,29 @@ func startModule() error { return nil } +func (r *Runtime) Stop(m *mgr.Manager) error { + return nil +} + // Register is like Registry.Register but uses // the package DefaultRegistry. 
func Register(key string, provider ValueProvider) (PushFunc, error) { return DefaultRegistry.Register(key, provider) } + +var ( + module *Runtime + shimLoaded atomic.Bool +) + +func New(instance instance) (*Runtime, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Runtime{instance: instance} + + return module, nil +} + +type instance interface{} diff --git a/base/template/module.go b/base/template/module.go index a46a74edc..bbf3b71ad 100644 --- a/base/template/module.go +++ b/base/template/module.go @@ -1,111 +1,109 @@ package template -import ( - "context" - "time" - - "github.com/safing/portmaster/base/config" - "github.com/safing/portmaster/base/modules" - "github.com/safing/portmaster/base/modules/subsystems" -) - -const ( - eventStateUpdate = "state update" -) - -var module *modules.Module - -func init() { - // register module - module = modules.Register("template", prep, start, stop) // add dependencies... - subsystems.Register( - "template-subsystem", // ID - "Template Subsystem", // name - "This subsystem is a template for quick setup", // description - module, - "config:template", // key space for configuration options registered - &config.Option{ - Name: "Template Subsystem", - Key: "config:subsystems/template", - Description: "This option enables the Template Subsystem [TEMPLATE]", - OptType: config.OptTypeBool, - DefaultValue: false, - }, - ) - - // register events that other modules can subscribe to - module.RegisterEvent(eventStateUpdate, true) -} - -func prep() error { - // register options - err := config.Register(&config.Option{ - Name: "language", - Key: "template/language", - Description: "Sets the language for the template [TEMPLATE]", - OptType: config.OptTypeString, - ExpertiseLevel: config.ExpertiseLevelUser, // default - ReleaseLevel: config.ReleaseLevelStable, // default - RequiresRestart: false, // default - DefaultValue: "en", - ValidationRegex: "^[a-z]{2}$", - }) - if err != nil { - return err - } - - // register event hooks - // do this in prep() and not in start(), as we don't want to register again if module is turned off and on again - err = module.RegisterEventHook( - "template", // event source module name - "state update", // event source name - "react to state changes", // description of hook function - eventHandler, // hook function - ) - if err != nil { - return err - } - - // hint: event hooks and tasks will not be run if module isn't online - return nil -} - -func start() error { - // register tasks - module.NewTask("do something", taskFn).Queue() - - // start service worker - module.StartServiceWorker("do something", 0, serviceWorker) - - return nil -} - -func stop() error { - return nil -} - -func serviceWorker(ctx context.Context) error { - for { - select { - case <-time.After(1 * time.Second): - err := do() - if err != nil { - return err - } - case <-ctx.Done(): - return nil - } - } -} - -func taskFn(ctx context.Context, task *modules.Task) error { - return do() -} - -func eventHandler(ctx context.Context, data interface{}) error { - return do() -} - -func do() error { - return nil -} +// import ( +// "context" +// "time" + +// "github.com/safing/portmaster/base/config" +// ) + +// const ( +// eventStateUpdate = "state update" +// ) + +// var module *modules.Module + +// func init() { +// // register module +// module = modules.Register("template", prep, start, stop) // add dependencies... 
+// subsystems.Register( +// "template-subsystem", // ID +// "Template Subsystem", // name +// "This subsystem is a template for quick setup", // description +// module, +// "config:template", // key space for configuration options registered +// &config.Option{ +// Name: "Template Subsystem", +// Key: "config:subsystems/template", +// Description: "This option enables the Template Subsystem [TEMPLATE]", +// OptType: config.OptTypeBool, +// DefaultValue: false, +// }, +// ) + +// // register events that other modules can subscribe to +// module.RegisterEvent(eventStateUpdate, true) +// } + +// func prep() error { +// // register options +// err := config.Register(&config.Option{ +// Name: "language", +// Key: "template/language", +// Description: "Sets the language for the template [TEMPLATE]", +// OptType: config.OptTypeString, +// ExpertiseLevel: config.ExpertiseLevelUser, // default +// ReleaseLevel: config.ReleaseLevelStable, // default +// RequiresRestart: false, // default +// DefaultValue: "en", +// ValidationRegex: "^[a-z]{2}$", +// }) +// if err != nil { +// return err +// } + +// // register event hooks +// // do this in prep() and not in start(), as we don't want to register again if module is turned off and on again +// err = module.RegisterEventHook( +// "template", // event source module name +// "state update", // event source name +// "react to state changes", // description of hook function +// eventHandler, // hook function +// ) +// if err != nil { +// return err +// } + +// // hint: event hooks and tasks will not be run if module isn't online +// return nil +// } + +// func start() error { +// // register tasks +// module.NewTask("do something", taskFn).Queue() + +// // start service worker +// module.StartServiceWorker("do something", 0, serviceWorker) + +// return nil +// } + +// func stop() error { +// return nil +// } + +// func serviceWorker(ctx context.Context) error { +// for { +// select { +// case <-time.After(1 * time.Second): +// err := do() +// if err != nil { +// return err +// } +// case <-ctx.Done(): +// return nil +// } +// } +// } + +// func taskFn(ctx context.Context, task *modules.Task) error { +// return do() +// } + +// func eventHandler(ctx context.Context, data interface{}) error { +// return do() +// } + +// func do() error { +// return nil +// } diff --git a/service/firewall/interception/module.go b/service/firewall/interception/module.go index 57c5a9611..eaa013762 100644 --- a/service/firewall/interception/module.go +++ b/service/firewall/interception/module.go @@ -2,14 +2,19 @@ package interception import ( "flag" + "sync/atomic" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/network/packet" ) +type Interception struct { + instance instance +} + var ( - module *modules.Module + module *Interception + shimLoaded atomic.Bool // Packets is a stream of interception network packest. 
Packets = make(chan packet.Packet, 1000) @@ -23,7 +28,7 @@ var ( func init() { flag.BoolVar(&disableInterception, "disable-interception", false, "disable packet interception; this breaks a lot of functionality") - module = modules.Register("interception", prep, start, stop, "base", "updates", "network", "notifications", "profiles") + // module = modules.Register("interception", prep, start, stop, "base", "updates", "network", "notifications", "profiles") } func prep() error { @@ -61,3 +66,5 @@ func stop() error { return stopInterception() } + +type instance interface{} diff --git a/service/instance.go b/service/instance.go index 8f34dad2d..1ab5533f1 100644 --- a/service/instance.go +++ b/service/instance.go @@ -6,11 +6,21 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/firewall" "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/profile" + "github.com/safing/portmaster/service/status" "github.com/safing/portmaster/service/ui" "github.com/safing/portmaster/service/updates" + "github.com/safing/portmaster/spn/access" + "github.com/safing/portmaster/spn/cabin" + "github.com/safing/portmaster/spn/captain" + "github.com/safing/portmaster/spn/crew" + "github.com/safing/portmaster/spn/docks" ) // Instance is an instance of a portmaste service. @@ -19,14 +29,25 @@ type Instance struct { version string - api *api.API - ui *ui.UI + api *api.API + config *config.Config + metrics *metrics.Metrics + runtime *runtime.Runtime + notifications *notifications.Notifications + rng *rng.Rng + + access *access.Access + cabin *cabin.Cabin + captain *captain.Captain + crew *crew.Crew + docks *docks.Docks + updates *updates.Updates - config *config.Config + ui *ui.UI profile *profile.ProfileModule - metrics *metrics.Metrics - - filter *firewall.Filter + filter *firewall.Filter + netenv *netenv.NetEnv + status *status.Status } // New returns a new portmaster service instance. 
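// Illustrative sketch, not part of the patch: the minimal module shape that the
// instance wiring below expects. "Example", its worker, and the package name are
// placeholders; the mgr signatures (Start/Stop, Manager.Go, WorkerCtx) and the
// single-instance New(instance) constructor are taken from this patch series.

package example

import (
	"errors"
	"sync/atomic"

	"github.com/safing/portmaster/service/mgr"
)

type Example struct {
	mgr      *mgr.Manager
	instance instance
}

// Start receives the module's manager from the instance group.
func (e *Example) Start(m *mgr.Manager) error {
	e.mgr = m
	// Long-running work is handed to the manager (replaces StartServiceWorker).
	m.Go("example worker", e.run)
	return nil
}

// Stop is called on shutdown; running workers are cancelled via their WorkerCtx.
func (e *Example) Stop(m *mgr.Manager) error { return nil }

func (e *Example) run(ctx *mgr.WorkerCtx) error {
	<-ctx.Done() // Wait for shutdown.
	return nil
}

var (
	module     *Example
	shimLoaded atomic.Bool
)

// New enforces the single-instance rule used by the modules in this series.
func New(instance instance) (*Example, error) {
	if !shimLoaded.CompareAndSwap(false, true) {
		return nil, errors.New("only one instance allowed")
	}
	module = &Example{instance: instance}
	return module, nil
}

type instance interface{}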
@@ -51,6 +72,40 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create metrics module: %w", err) } + instance.runtime, err = runtime.New(instance) + if err != nil { + return nil, fmt.Errorf("create runtime module: %w", err) + } + instance.notifications, err = notifications.New(instance) + if err != nil { + return nil, fmt.Errorf("create runtime module: %w", err) + } + instance.rng, err = rng.New(instance) + if err != nil { + return nil, fmt.Errorf("create rng module: %w", err) + } + + // SPN modules + instance.access, err = access.New(instance) + if err != nil { + return nil, fmt.Errorf("create access module: %w", err) + } + instance.cabin, err = cabin.New(instance) + if err != nil { + return nil, fmt.Errorf("create cabin module: %w", err) + } + instance.captain, err = captain.New(instance) + if err != nil { + return nil, fmt.Errorf("create captain module: %w", err) + } + instance.crew, err = crew.New(instance) + if err != nil { + return nil, fmt.Errorf("create crew module: %w", err) + } + instance.docks, err = docks.New(instance) + if err != nil { + return nil, fmt.Errorf("create docks module: %w", err) + } // Service modules instance.updates, err = updates.New(instance, svcCfg.ShutdownFunc) @@ -69,17 +124,36 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create filter module: %w", err) } + instance.netenv, err = netenv.New(instance) + if err != nil { + return nil, fmt.Errorf("create netenv module: %w", err) + } + instance.status, err = status.New(instance) + if err != nil { + return nil, fmt.Errorf("create status module: %w", err) + } // Add all modules to instance group. instance.Group = mgr.NewGroup( instance.config, instance.api, instance.metrics, + instance.runtime, + instance.notifications, + instance.rng, + + instance.access, + instance.cabin, + instance.captain, + instance.crew, + instance.docks, instance.updates, instance.ui, instance.profile, instance.filter, + instance.netenv, + instance.status, ) return instance, nil @@ -100,6 +174,46 @@ func (i *Instance) Metrics() *metrics.Metrics { return i.metrics } +// Runtime returns the runtime module. +func (i *Instance) Runtime() *runtime.Runtime { + return i.runtime +} + +// Notifications returns the notifications module. +func (i *Instance) Notifications() *notifications.Notifications { + return i.notifications +} + +// Rng returns the rng module. +func (i *Instance) Rng() *rng.Rng { + return i.rng +} + +// Access returns the access module. +func (i *Instance) Access() *access.Access { + return i.access +} + +// Cabin returns the cabin module. +func (i *Instance) Cabin() *cabin.Cabin { + return i.cabin +} + +// Captain returns the captain module. +func (i *Instance) Captain() *captain.Captain { + return i.captain +} + +// Crew returns the crew module. +func (i *Instance) Crew() *crew.Crew { + return i.crew +} + +// Crew returns the crew module. +func (i *Instance) Docks() *docks.Docks { + return i.docks +} + // UI returns the ui module. func (i *Instance) UI() *ui.UI { return i.ui @@ -119,3 +233,13 @@ func (i *Instance) Profile() *profile.ProfileModule { func (i *Instance) Firewall() *firewall.Filter { return i.filter } + +// NetEnv returns the netenv module. +func (i *Instance) NetEnv() *netenv.NetEnv { + return i.netenv +} + +// Status returns the status module. 
+func (i *Instance) Status() *status.Status { + return i.status +} diff --git a/service/netenv/main.go b/service/netenv/main.go index c94062a76..8f29df8b5 100644 --- a/service/netenv/main.go +++ b/service/netenv/main.go @@ -1,10 +1,13 @@ package netenv import ( + "errors" + "sync/atomic" + "github.com/tevino/abool" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) // Event Names. @@ -14,44 +17,49 @@ const ( OnlineStatusChangedEvent = "online status changed" ) -var module *modules.Module +type NetEnv struct { + instance instance -func init() { - module = modules.Register(ModuleName, prep, start, nil) - module.RegisterEvent(NetworkChangedEvent, true) - module.RegisterEvent(OnlineStatusChangedEvent, true) + EventNetworkChange *mgr.EventMgr[struct{}] + EventOnlineStatusChange *mgr.EventMgr[OnlineStatus] } -func prep() error { - checkForIPv6Stack() - - if err := registerAPIEndpoints(); err != nil { - return err - } - - if err := prepOnlineStatus(); err != nil { +func (ne *NetEnv) Start(m *mgr.Manager) error { + if err := prep(); err != nil { return err } - return prepLocation() -} - -func start() error { - module.StartServiceWorker( + m.Go( "monitor network changes", - 0, monitorNetworkChanges, ) - module.StartServiceWorker( + m.Go( "monitor online status", - 0, monitorOnlineStatus, ) return nil } +func (ne *NetEnv) Stop(m *mgr.Manager) error { + return nil +} + +func prep() error { + checkForIPv6Stack() + + if err := registerAPIEndpoints(); err != nil { + return err + } + + if err := prepOnlineStatus(); err != nil { + return err + } + + return prepLocation() +} + var ipv6Enabled = abool.NewBool(true) // IPv6Enabled returns whether the device has an active IPv6 stack. @@ -70,3 +78,26 @@ func checkForIPv6Stack() { // Set IPv6 as enabled if any IPv6 addresses are found. ipv6Enabled.SetTo(len(v6IPs) > 0) } + +var ( + module *NetEnv + shimLoaded atomic.Bool +) + +// New returns a new UI module. +func New(instance instance) (*NetEnv, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + if err := prep(); err != nil { + return nil, err + } + + module = &NetEnv{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/service/netenv/network-change.go b/service/netenv/network-change.go index 143256f90..f94139872 100644 --- a/service/netenv/network-change.go +++ b/service/netenv/network-change.go @@ -2,13 +2,13 @@ package netenv import ( "bytes" - "context" "crypto/sha1" "io" "time" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/utils" + "github.com/safing/portmaster/service/mgr" ) var ( @@ -23,7 +23,7 @@ func GetNetworkChangedFlag() *utils.Flag { func notifyOfNetworkChange() { networkChangedBroadcastFlag.NotifyAndReset() - module.TriggerEvent(NetworkChangedEvent, nil) + module.EventNetworkChange.Submit(struct{}{}) } // TriggerNetworkChangeCheck triggers a network change check. 
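// Illustrative sketch, not part of the patch: how the typed mgr.EventMgr replaces
// RegisterEvent/TriggerEvent and RegisterEventHook, as with the network-change
// event above. "publisher", "subscribe" and the event constant are placeholder
// names; NewEventMgr, Submit and AddCallback are the calls introduced here, and
// the `false` return mirrors the persistent hooks converted in this series.

package example

import "github.com/safing/portmaster/service/mgr"

const exampleNetworkChangedEvent = "network changed"

type publisher struct {
	EventNetworkChange *mgr.EventMgr[struct{}]
}

// Start creates the event on the module's manager.
func (p *publisher) Start(m *mgr.Manager) error {
	p.EventNetworkChange = mgr.NewEventMgr[struct{}](exampleNetworkChangedEvent, m)
	return nil
}

// notify fires the event (replaces module.TriggerEvent(NetworkChangedEvent, nil)).
func (p *publisher) notify() {
	p.EventNetworkChange.Submit(struct{}{})
}

// subscribe registers a callback (replaces module.RegisterEventHook).
func subscribe(p *publisher) {
	p.EventNetworkChange.AddCallback("react to network change",
		func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) {
			// React to the change here.
			return false, nil
		})
}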
@@ -34,7 +34,7 @@ func TriggerNetworkChangeCheck() { } } -func monitorNetworkChanges(ctx context.Context) error { +func monitorNetworkChanges(ctx *mgr.WorkerCtx) error { var lastNetworkChecksum []byte serviceLoop: diff --git a/service/netenv/online-status.go b/service/netenv/online-status.go index 018d3dc16..554fc0046 100644 --- a/service/netenv/online-status.go +++ b/service/netenv/online-status.go @@ -15,6 +15,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/updates" ) @@ -207,7 +208,7 @@ func updateOnlineStatus(status OnlineStatus, portalURL *url.URL, comment string) // Trigger events. if changed { - module.TriggerEvent(OnlineStatusChangedEvent, status) + module.EventOnlineStatusChange.Submit(status) if status == StatusPortal { log.Infof(`netenv: setting online status to %s at "%s" (%s)`, status, portalURL, comment) } else { @@ -356,7 +357,7 @@ func TriggerOnlineStatusInvestigation() { } } -func monitorOnlineStatus(ctx context.Context) error { +func monitorOnlineStatus(ctx *mgr.WorkerCtx) error { TriggerOnlineStatusInvestigation() for { // wait for trigger @@ -372,7 +373,7 @@ func monitorOnlineStatus(ctx context.Context) error { onlineStatusInvestigationWg.Add(1) } - checkOnlineStatus(ctx) + checkOnlineStatus(ctx.Ctx()) // finished! onlineStatusInvestigationWg.Done() diff --git a/service/status/module.go b/service/status/module.go index bb6d29fb3..bd414ffb6 100644 --- a/service/status/module.go +++ b/service/status/module.go @@ -1,36 +1,35 @@ package status import ( - "context" + "errors" "fmt" + "sync/atomic" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/utils/debug" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" ) -var module *modules.Module - -func init() { - module = modules.Register("status", nil, start, nil, "base", "config") +type Status struct { + instance instance } -func start() error { +func (s *Status) Start(m *mgr.Manager) error { if err := setupRuntimeProvider(); err != nil { return err } - if err := module.RegisterEventHook( - netenv.ModuleName, - netenv.OnlineStatusChangedEvent, - "update online status in system status", - func(_ context.Context, _ interface{}) error { + s.instance.NetEnv().EventOnlineStatusChange.AddCallback("update online status in system status", + func(_ *mgr.WorkerCtx, _ netenv.OnlineStatus) (bool, error) { pushSystemStatus() - return nil + return false, nil }, - ); err != nil { - return err - } + ) + + return nil +} + +func (s *Status) Stop(m *mgr.Manager) error { return nil } @@ -43,3 +42,24 @@ func AddToDebugInfo(di *debug.Info) { "CaptivePortal: "+netenv.GetCaptivePortal().URL, ) } + +var ( + module *Status + shimLoaded atomic.Bool +) + +func New(instance instance) (*Status, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Status{ + instance: instance, + } + + return module, nil +} + +type instance interface { + NetEnv() *netenv.NetEnv +} diff --git a/spn/access/client.go b/spn/access/client.go index 0e08c8880..b52e88433 100644 --- a/spn/access/client.go +++ b/spn/access/client.go @@ -214,7 +214,7 @@ func Login(username, password string) (user *UserRecord, code int, err error) { defer clientRequestLock.Unlock() // Trigger account update when done. 
- defer module.TriggerEvent(AccountUpdateEvent, nil) + defer module.EventAccountUpdate.Submit(struct{}{}) // Get previous user. previousUser, err := GetUser() @@ -300,7 +300,7 @@ func Logout(shallow, purge bool) error { defer clientRequestLock.Unlock() // Trigger account update when done. - defer module.TriggerEvent(AccountUpdateEvent, nil) + defer module.EventAccountUpdate.Submit(struct{}{}) // Clear caches. clearUserCaches() @@ -382,7 +382,7 @@ func UpdateUser() (user *UserRecord, statusCode int, err error) { defer clientRequestLock.Unlock() // Trigger account update when done. - defer module.TriggerEvent(AccountUpdateEvent, nil) + defer module.EventAccountUpdate.Submit(struct{}{}) // Create request options. userData := &account.User{} diff --git a/spn/access/module.go b/spn/access/module.go index 5779d52ed..3b6333a26 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -1,25 +1,47 @@ package access import ( - "context" "errors" "fmt" + "sync/atomic" "time" "github.com/tevino/abool" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/access/account" "github.com/safing/portmaster/spn/access/token" "github.com/safing/portmaster/spn/conf" ) +type Access struct { + mgr *mgr.Manager + instance instance + + EventAccountUpdate *mgr.EventMgr[struct{}] +} + +func (a *Access) Start(m *mgr.Manager) error { + a.mgr = m + a.EventAccountUpdate = mgr.NewEventMgr[struct{}](AccountUpdateEvent, m) + if err := prep(); err != nil { + return err + } + + return start() +} + +func (a *Access) Stop(m *mgr.Manager) error { + return stop() +} + var ( - module *modules.Module + module *Access + shimLoaded atomic.Bool - accountUpdateTask *modules.Task + // accountUpdateTask *modules.Task tokenIssuerIsFailing = abool.New() tokenIssuerRetryDuration = 10 * time.Minute @@ -38,13 +60,7 @@ var ( ErrNotLoggedIn = errors.New("not logged in") ) -func init() { - module = modules.Register("access", prep, start, stop, "terminal") -} - func prep() error { - module.RegisterEvent(AccountUpdateEvent, true) - // Register API handlers. if conf.Client() { err := registerAPIEndpoints() @@ -67,7 +83,7 @@ func start() error { loadTokens() // Register new task. - accountUpdateTask = module.NewTask( + accountUpdateTask = module.mgr.Go( "update account", UpdateAccount, ).Repeat(24 * time.Hour).Schedule(time.Now().Add(1 * time.Minute)) @@ -93,7 +109,7 @@ func stop() error { } // UpdateAccount updates the user account and fetches new tokens, if needed. -func UpdateAccount(_ context.Context, task *modules.Task) error { +func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error { // Retry sooner if the token issuer is failing. defer func() { if tokenIssuerIsFailing.IsSet() && task != nil { @@ -192,3 +208,21 @@ func (user *UserRecord) MayUseTheSPN() bool { return user.User.MayUseSPN() } + +// New returns a new Access module. 
+func New(instance instance) (*Access, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + if err := prep(); err != nil { + return nil, err + } + + module = &Access{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/cabin/module.go b/spn/cabin/module.go index 3a1dd78e2..8e35b6128 100644 --- a/spn/cabin/module.go +++ b/spn/cabin/module.go @@ -1,16 +1,30 @@ package cabin import ( - "github.com/safing/portmaster/base/modules" + "errors" + "sync/atomic" + + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) -var module *modules.Module +type Cabin struct { + instance instance +} + +func (c *Cabin) Start(m *mgr.Manager) error { + return prep() +} -func init() { - module = modules.Register("cabin", prep, nil, nil, "base", "rng") +func (c *Cabin) Stop(m *mgr.Manager) error { + return nil } +var ( + module *Cabin + shimLoaded atomic.Bool +) + func prep() error { if err := initProvidedExchKeySchemes(); err != nil { return err @@ -24,3 +38,21 @@ func prep() error { return nil } + +// New returns a new Cabin module. +func New(instance instance) (*Cabin, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + if err := prep(); err != nil { + return nil, err + } + + module = &Cabin{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/captain/client.go b/spn/captain/client.go index b1a81d413..d08727f63 100644 --- a/spn/captain/client.go +++ b/spn/captain/client.go @@ -10,6 +10,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/spn/access" @@ -71,13 +72,13 @@ func triggerClientHealthCheck() { } } -func clientManager(ctx context.Context) error { +func clientManager(ctx *mgr.WorkerCtx) error { defer func() { ready.UnSet() netenv.ConnectedToSPN.UnSet() resetSPNStatus(StatusDisabled, true) module.Resolve("") - clientStopHomeHub(ctx) + clientStopHomeHub(ctx.Ctx()) }() module.Hint( @@ -123,7 +124,7 @@ reconnect: clientConnectToHomeHub, clientSetActiveConnectionStatus, } { - switch clientFunc(ctx) { + switch clientFunc(ctx.Ctx()) { case clientResultOk: // Continue case clientResultRetry, clientResultReconnect: @@ -143,8 +144,8 @@ reconnect: ready.Set() netenv.ConnectedToSPN.Set() - module.TriggerEvent(SPNConnectedEvent, nil) - module.StartWorker("update quick setting countries", navigator.Main.UpdateConfigQuickSettings) + module.EventSPNConnected.Submit(struct{}{}) + module.mgr.Go("update quick setting countries", navigator.Main.UpdateConfigQuickSettings) // Reset last health check value, as we have just connected. 
lastHealthCheck = time.Now() @@ -164,7 +165,7 @@ reconnect: clientCheckAccountAndTokens, clientSetActiveConnectionStatus, } { - switch clientFunc(ctx) { + switch clientFunc(ctx.Ctx()) { case clientResultOk: // Continue case clientResultRetry: diff --git a/spn/captain/module.go b/spn/captain/module.go index 9538f5b7a..55299f292 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -6,14 +6,17 @@ import ( "fmt" "net" "net/http" + "sync/atomic" "time" + "github.com/safing/portbase/modules/subsystems" "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/modules" - "github.com/safing/portmaster/base/modules/subsystems" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/crew" @@ -25,14 +28,31 @@ import ( const controlledFailureExitCode = 24 -var module *modules.Module - // SPNConnectedEvent is the name of the event that is fired when the SPN has connected and is ready. const SPNConnectedEvent = "spn connect" +type Captain struct { + mgr *mgr.Manager + instance instance + + EventSPNConnected *mgr.EventMgr[struct{}] +} + +func (c *Captain) Start(m *mgr.Manager) error { + c.mgr = m + c.EventSPNConnected = mgr.NewEventMgr[struct{}](SPNConnectedEvent, m) + if err := prep(); err != nil { + return err + } + + return start() +} + +func (c *Captain) Stop(m *mgr.Manager) error { + return stop() +} + func init() { - module = modules.Register("captain", prep, start, stop, "base", "terminal", "cabin", "ships", "docks", "crew", "navigator", "sluice", "patrol", "netenv") - module.RegisterEvent(SPNConnectedEvent, false) subsystems.Register( "spn", "SPN", @@ -102,7 +122,7 @@ func start() error { if err := registerIntelUpdateHook(); err != nil { return err } - if err := updateSPNIntel(module.Ctx, nil); err != nil { + if err := updateSPNIntel(module.mgr.Ctx(), nil); err != nil { log.Errorf("spn/captain: failed to update SPN intel: %s", err) } @@ -152,29 +172,22 @@ func start() error { // network optimizer if conf.PublicHub() { - module.NewTask("optimize network", optimizeNetwork). + module.mgr.Go("optimize network", optimizeNetwork). Repeat(1 * time.Minute). Schedule(time.Now().Add(15 * time.Second)) } // client + home hub manager if conf.Client() { - module.StartServiceWorker("client manager", 0, clientManager) + module.mgr.Go("client manager", clientManager) // Reset failing hubs when the network changes while not connected. - if err := module.RegisterEventHook( - "netenv", - "network changed", - "reset failing hubs", - func(_ context.Context, _ interface{}) error { - if ready.IsNotSet() { - navigator.Main.ResetFailingStates(module.Ctx) - } - return nil - }, - ); err != nil { - return err - } + module.instance.NetEnv().EventNetworkChange.AddCallback("reset failing hubs", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { + if ready.IsNotSet() { + navigator.Main.ResetFailingStates(module.mgr.Ctx()) + } + return false, nil + }) } return nil @@ -217,3 +230,24 @@ func apiAuthenticator(r *http.Request, s *http.Server) (*api.AuthToken, error) { Write: api.PermitUser, }, nil } + +var ( + module *Captain + shimLoaded atomic.Bool +) + +// New returns a new Captain module. 
+func New(instance instance) (*Captain, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Captain{ + instance: instance, + } + return module, nil +} + +type instance interface { + NetEnv() *netenv.NetEnv +} diff --git a/spn/captain/navigation.go b/spn/captain/navigation.go index 77e77c078..5b6210b73 100644 --- a/spn/captain/navigation.go +++ b/spn/captain/navigation.go @@ -9,6 +9,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/intel" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/spn/access" @@ -210,7 +211,7 @@ func connectToHomeHub(ctx context.Context, dst *hub.Hub) error { return nil } -func optimizeNetwork(ctx context.Context, task *modules.Task) error { +func optimizeNetwork(ctx *mgr.WorkerCtx) error { //, task *modules.Task) error { if publicIdentity == nil { return nil } diff --git a/spn/crew/connect.go b/spn/crew/connect.go index ca11080c5..7b1edb4be 100644 --- a/spn/crew/connect.go +++ b/spn/crew/connect.go @@ -10,6 +10,7 @@ import ( "time" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/spn/access" @@ -37,7 +38,7 @@ func HandleSluiceRequest(connInfo *network.Connection, conn net.Conn) { connInfo: connInfo, conn: conn, } - module.StartWorker("tunnel handler", t.connectWorker) + module.mgr.Go("tunnel handler", t.connectWorker) } // Tunnel represents the local information and endpoint of a data tunnel. @@ -52,9 +53,9 @@ type Tunnel struct { stickied bool } -func (t *Tunnel) connectWorker(ctx context.Context) (err error) { +func (t *Tunnel) connectWorker(wc *mgr.WorkerCtx) (err error) { // Get tracing logger. - ctx, tracer := log.AddTracer(ctx) + ctx, tracer := log.AddTracer(wc.Ctx()) defer tracer.Submit() // Save start time. diff --git a/spn/crew/module.go b/spn/crew/module.go index 5bbf89294..2a677435d 100644 --- a/spn/crew/module.go +++ b/spn/crew/module.go @@ -1,16 +1,26 @@ package crew import ( + "errors" + "sync/atomic" "time" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" ) -var module *modules.Module +type Crew struct { + mgr *mgr.Manager + instance instance +} + +func (c *Crew) Start(m *mgr.Manager) error { + c.mgr = m + return start() +} -func init() { - module = modules.Register("crew", nil, start, stop, "terminal", "docks", "navigator", "intel", "cabin") +func (c *Crew) Stop(m *mgr.Manager) error { + return stop() } func start() error { @@ -42,3 +52,22 @@ func reportConnectError(tErr *terminal.Error) { func ConnectErrors() <-chan *terminal.Error { return connectErrors } + +var ( + module *Crew + shimLoaded atomic.Bool +) + +// New returns a new Config module. 
+func New(instance instance) (*Crew, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Crew{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/crew/op_connect.go b/spn/crew/op_connect.go index 228047b79..82079d9f3 100644 --- a/spn/crew/op_connect.go +++ b/spn/crew/op_connect.go @@ -13,6 +13,7 @@ import ( "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/spn/conf" @@ -141,7 +142,7 @@ func NewConnectOp(tunnel *Tunnel) (*ConnectOp, *terminal.Error) { entry: true, tunnel: tunnel, } - op.ctx, op.cancelCtx = context.WithCancel(module.Ctx) + op.ctx, op.cancelCtx = context.WithCancel(module.mgr.Ctx()) op.dfq = terminal.NewDuplexFlowQueue(op.Ctx(), request.QueueSize, op.submitUpstream) // Prepare init msg. @@ -159,9 +160,9 @@ func NewConnectOp(tunnel *Tunnel) (*ConnectOp, *terminal.Error) { // Setup metrics. op.started = time.Now() - module.StartWorker("connect op conn reader", op.connReader) - module.StartWorker("connect op conn writer", op.connWriter) - module.StartWorker("connect op flow handler", op.dfq.FlowHandler) + module.mgr.Go("connect op conn reader", op.connReader) + module.mgr.Go("connect op conn writer", op.connWriter) + module.mgr.Go("connect op flow handler", op.dfq.FlowHandler) log.Infof("spn/crew: connected to %s via %s", request, tunnel.dstPin.Hub) return op, nil @@ -202,12 +203,12 @@ func startConnectOp(t terminal.Terminal, opID uint32, data *container.Container) op.dfq = terminal.NewDuplexFlowQueue(op.Ctx(), request.QueueSize, op.submitUpstream) // Start worker to complete setting up the connection. - module.StartWorker("connect op setup", op.handleSetup) + module.mgr.Go("connect op setup", op.handleSetup) return op, nil } -func (op *ConnectOp) handleSetup(_ context.Context) error { +func (op *ConnectOp) handleSetup(_ *mgr.WorkerCtx) error { // Get terminal session for rate limiting. var session *terminal.Session if sessionTerm, ok := op.t.(terminal.SessionTerminal); ok { @@ -309,9 +310,9 @@ func (op *ConnectOp) setup(session *terminal.Session) { op.conn = conn // Start worker. - module.StartWorker("connect op conn reader", op.connReader) - module.StartWorker("connect op conn writer", op.connWriter) - module.StartWorker("connect op flow handler", op.dfq.FlowHandler) + module.mgr.Go("connect op conn reader", op.connReader) + module.mgr.Go("connect op conn writer", op.connWriter) + module.mgr.Go("connect op flow handler", op.dfq.FlowHandler) connectOpCntConnected.Inc() log.Infof("spn/crew: connected op %s#%d to %s", op.t.FmtID(), op.ID(), op.request) @@ -337,7 +338,7 @@ const ( rateLimitMaxMbit = 128 ) -func (op *ConnectOp) connReader(_ context.Context) error { +func (op *ConnectOp) connReader(_ *mgr.WorkerCtx) error { // Metrics setup and submitting. atomic.AddInt64(activeConnectOps, 1) defer func() { @@ -403,7 +404,7 @@ func (op *ConnectOp) Deliver(msg *terminal.Msg) *terminal.Error { return op.dfq.Deliver(msg) } -func (op *ConnectOp) connWriter(_ context.Context) error { +func (op *ConnectOp) connWriter(_ *mgr.WorkerCtx) error { // Metrics submitting. 
defer func() { connectOpOutgoingDataHistogram.Update(float64(op.outgoingTraffic.Load())) diff --git a/spn/docks/crane.go b/spn/docks/crane.go index c7e65b41b..5d1894462 100644 --- a/spn/docks/crane.go +++ b/spn/docks/crane.go @@ -16,6 +16,7 @@ import ( "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/ships" @@ -110,7 +111,7 @@ type Crane struct { // NewCrane returns a new crane. func NewCrane(ship ships.Ship, connectedHub *hub.Hub, id *cabin.Identity) (*Crane, error) { // Cranes always run in module context. - ctx, cancelCtx := context.WithCancel(module.Ctx) + ctx, cancelCtx := context.WithCancel(module.mgr.Ctx()) newCrane := &Crane{ ctx: ctx, @@ -351,7 +352,7 @@ func (crane *Crane) AbandonTerminal(id uint32, err *terminal.Error) { if crane.stopping.IsSet() && crane.terminalCount() <= 1 { // Stop the crane in worker, so the caller can do some work. - module.StartWorker("retire crane", func(_ context.Context) error { + module.mgr.Go("retire crane", func(_ *mgr.WorkerCtx) error { // Let enough time for the last errors to be sent, as terminals are abandoned in a goroutine. time.Sleep(3 * time.Second) crane.Stop(nil) @@ -618,7 +619,7 @@ handling: if deliveryErr != nil { msg.Finish() // This is a hot path. Start a worker for abandoning the terminal. - module.StartWorker("end terminal", func(_ context.Context) error { + module.mgr.Go("end terminal", func(_ *mgr.WorkerCtx) error { crane.AbandonTerminal(t.ID(), deliveryErr.Wrap("failed to deliver data")) return nil }) @@ -635,7 +636,7 @@ handling: receivedErr = terminal.ErrUnknownError.AsExternal() } // This is a hot path. Start a worker for abandoning the terminal. - module.StartWorker("end terminal", func(_ context.Context) error { + module.mgr.Go("end terminal", func(_ *mgr.WorkerCtx) error { crane.AbandonTerminal(terminalID, receivedErr) return nil }) diff --git a/spn/docks/module.go b/spn/docks/module.go index f79b27cf3..2c2f17b97 100644 --- a/spn/docks/module.go +++ b/spn/docks/module.go @@ -5,15 +5,28 @@ import ( "errors" "fmt" "sync" + "sync/atomic" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" _ "github.com/safing/portmaster/spn/access" ) -var ( - module *modules.Module +type Docks struct { + mgr *mgr.Manager + instance instance +} +func (d *Docks) Start(m *mgr.Manager) error { + d.mgr = m + return start() +} + +func (d *Docks) Stop(m *mgr.Manager) error { + return stopAllCranes() +} + +var ( allCranes = make(map[string]*Crane) // ID = Crane ID assignedCranes = make(map[string]*Crane) // ID = connected Hub ID cranesLock sync.RWMutex @@ -21,10 +34,6 @@ var ( runningTests bool ) -func init() { - module = modules.Register("docks", nil, start, stopAllCranes, "terminal", "cabin", "access") -} - func start() error { return registerMetrics() } @@ -115,3 +124,22 @@ func GetAllAssignedCranes() map[string]*Crane { } return copiedCranes } + +var ( + module *Docks + shimLoaded atomic.Bool +) + +// New returns a new Docks module. 
+func New(instance instance) (*Docks, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Docks{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/navigator/update.go b/spn/navigator/update.go index ee2ed1956..b04a58691 100644 --- a/spn/navigator/update.go +++ b/spn/navigator/update.go @@ -18,6 +18,7 @@ import ( "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/intel/geoip" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/spn/hub" @@ -558,8 +559,8 @@ func (m *Map) addBootstrapHub(bootstrapTransport string) error { } // UpdateConfigQuickSettings updates config quick settings with available countries. -func (m *Map) UpdateConfigQuickSettings(ctx context.Context) error { - ctx, tracer := log.AddTracer(ctx) +func (m *Map) UpdateConfigQuickSettings(wc *mgr.WorkerCtx) error { + ctx, tracer := log.AddTracer(wc.Ctx()) tracer.Trace("navigator: updating SPN rules country quick settings") defer tracer.Submit() diff --git a/spn/terminal/control_flow.go b/spn/terminal/control_flow.go index b8572dda0..24685206a 100644 --- a/spn/terminal/control_flow.go +++ b/spn/terminal/control_flow.go @@ -9,6 +9,7 @@ import ( "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) // FlowControl defines the flow control interface. @@ -187,7 +188,7 @@ func (dfq *DuplexFlowQueue) reportableRecvSpace() int32 { // FlowHandler handles all flow queue internals and must be started as a worker // in the module where it is used. -func (dfq *DuplexFlowQueue) FlowHandler(_ context.Context) error { +func (dfq *DuplexFlowQueue) FlowHandler(_ *mgr.WorkerCtx) error { // The upstreamSender is started by the terminal module, but is tied to the // flow owner instead. Make sure that the flow owner's module depends on the // terminal module so that it is shut down earlier. From 14991aee9602223c5b305b71cc2d834307a83e80 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Thu, 13 Jun 2024 11:11:11 +0300 Subject: [PATCH 08/56] [WIP] switch all SPN modules --- service/instance.go | 72 +++++++++++++++++++++++++++++++++++---- spn/navigator/module.go | 43 +++++++++++++++++++---- spn/patrol/module.go | 40 ++++++++++++++++------ spn/ships/module.go | 36 ++++++++++++++++---- spn/sluice/module.go | 44 ++++++++++++++++++++---- spn/terminal/module.go | 41 +++++++++++++++++++--- spn/terminal/operation.go | 4 +-- spn/unit/scheduler.go | 3 +- 8 files changed, 240 insertions(+), 43 deletions(-) diff --git a/service/instance.go b/service/instance.go index 1ab5533f1..9a86401a5 100644 --- a/service/instance.go +++ b/service/instance.go @@ -21,6 +21,11 @@ import ( "github.com/safing/portmaster/spn/captain" "github.com/safing/portmaster/spn/crew" "github.com/safing/portmaster/spn/docks" + "github.com/safing/portmaster/spn/navigator" + "github.com/safing/portmaster/spn/patrol" + "github.com/safing/portmaster/spn/ships" + "github.com/safing/portmaster/spn/sluice" + "github.com/safing/portmaster/spn/terminal" ) // Instance is an instance of a portmaste service. 
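// Illustrative sketch, not part of the patch: the worker shape after the
// migration. Workers take *mgr.WorkerCtx instead of context.Context, stop when
// Done() closes, and use Ctx() where a plain context is still needed (as in the
// navigator and cleaner changes above). Names are placeholders; Go, Done and Ctx
// are the calls used throughout this series.

package example

import (
	"context"
	"time"

	"github.com/safing/portmaster/service/mgr"
)

func startWorker(m *mgr.Manager) {
	m.Go("periodic worker", periodicWorker)
}

func periodicWorker(ctx *mgr.WorkerCtx) error {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done(): // Manager shutdown cancels the worker.
			return nil
		case <-ticker.C:
			doWork(ctx.Ctx()) // Bridge to code that still takes a context.Context.
		}
	}
}

func doWork(_ context.Context) {
	// Placeholder for existing work that has not been migrated yet.
}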
@@ -36,11 +41,16 @@ type Instance struct { notifications *notifications.Notifications rng *rng.Rng - access *access.Access - cabin *cabin.Cabin - captain *captain.Captain - crew *crew.Crew - docks *docks.Docks + access *access.Access + cabin *cabin.Cabin + captain *captain.Captain + crew *crew.Crew + docks *docks.Docks + navigator *navigator.Navigator + patrol *patrol.Patrol + ships *ships.Ships + sluice *sluice.SluiceModule + terminal *terminal.TerminalModule updates *updates.Updates ui *ui.UI @@ -106,6 +116,26 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create docks module: %w", err) } + instance.navigator, err = navigator.New(instance) + if err != nil { + return nil, fmt.Errorf("create navigator module: %w", err) + } + instance.patrol, err = patrol.New(instance) + if err != nil { + return nil, fmt.Errorf("create patrol module: %w", err) + } + instance.ships, err = ships.New(instance) + if err != nil { + return nil, fmt.Errorf("create ships module: %w", err) + } + instance.sluice, err = sluice.New(instance) + if err != nil { + return nil, fmt.Errorf("create sluice module: %w", err) + } + instance.terminal, err = terminal.New(instance) + if err != nil { + return nil, fmt.Errorf("create terminal module: %w", err) + } // Service modules instance.updates, err = updates.New(instance, svcCfg.ShutdownFunc) @@ -147,6 +177,11 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.captain, instance.crew, instance.docks, + instance.navigator, + instance.patrol, + instance.ships, + instance.sluice, + instance.terminal, instance.updates, instance.ui, @@ -209,11 +244,36 @@ func (i *Instance) Crew() *crew.Crew { return i.crew } -// Crew returns the crew module. +// Docks returns the crew module. func (i *Instance) Docks() *docks.Docks { return i.docks } +// Navigator returns the navigator module. +func (i *Instance) Navigator() *navigator.Navigator { + return i.navigator +} + +// Patrol returns the patrol module. +func (i *Instance) Patrol() *patrol.Patrol { + return i.patrol +} + +// Ships returns the ships module. +func (i *Instance) Ships() *ships.Ships { + return i.ships +} + +// Sluice returns the ships module. +func (i *Instance) Sluice() *sluice.SluiceModule { + return i.sluice +} + +// Terminal returns the terminal module. +func (i *Instance) Terminal() *terminal.TerminalModule { + return i.terminal +} + // UI returns the ui module. func (i *Instance) UI() *ui.UI { return i.ui diff --git a/spn/navigator/module.go b/spn/navigator/module.go index 3e8c6fc10..c8027fb90 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -2,12 +2,13 @@ package navigator import ( "errors" + "sync/atomic" "time" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/intel/geoip" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) @@ -33,8 +34,28 @@ var ( ErrAllPinsDisregarded = errors.New("all pins have been disregarded") ) +type Navigator struct { + mgr *mgr.Manager + + instance instance +} + +func (n *Navigator) Start(m *mgr.Manager) error { + n.mgr = m + if err := prep(); err != nil { + return err + } + + return start() +} + +func (n *Navigator) Stop(m *mgr.Manager) error { + return stop() +} + var ( - module *modules.Module + module *Navigator + shimLoaded atomic.Bool // Main is the primary map used. 
Main *Map @@ -44,10 +65,6 @@ var ( cfgOptionTrustNodeNodes config.StringArrayOption ) -func init() { - module = modules.Register("navigator", prep, start, stop, "terminal", "geoip", "netenv") -} - func prep() error { return registerAPIEndpoints() } @@ -127,3 +144,17 @@ func stop() error { return nil } + +// New returns a new Navigator module. +func New(instance instance) (*Navigator, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Navigator{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/patrol/module.go b/spn/patrol/module.go index 9a66cee5d..0962f17fa 100644 --- a/spn/patrol/module.go +++ b/spn/patrol/module.go @@ -1,32 +1,52 @@ package patrol import ( + "errors" + "sync/atomic" "time" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) // ChangeSignalEventName is the name of the event that signals any change in the patrol system. const ChangeSignalEventName = "change signal" -var module *modules.Module +type Patrol struct { + instance instance -func init() { - module = modules.Register("patrol", prep, start, nil, "rng") + EventChangeSignal *mgr.EventMgr[struct{}] } -func prep() error { - module.RegisterEvent(ChangeSignalEventName, false) +func (p *Patrol) Start(m *mgr.Manager) error { + p.EventChangeSignal = mgr.NewEventMgr[struct{}](ChangeSignalEventName, m) - return nil -} - -func start() error { if conf.PublicHub() { module.NewTask("connectivity test", connectivityCheckTask). Repeat(5 * time.Minute) } + return nil +} +func (p *Patrol) Stop(m *mgr.Manager) error { return nil } + +var ( + module *Patrol + shimLoaded atomic.Bool +) + +// New returns a new Config module. +func New(instance instance) (*Patrol, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Patrol{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/ships/module.go b/spn/ships/module.go index 01543ac27..27b1fbc3b 100644 --- a/spn/ships/module.go +++ b/spn/ships/module.go @@ -1,20 +1,44 @@ package ships import ( - "github.com/safing/portmaster/base/modules" + "errors" + "sync/atomic" + + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) -var module *modules.Module - -func init() { - module = modules.Register("ships", start, nil, nil, "cabin") +type Ships struct { + instance instance } -func start() error { +func (s *Ships) Start(m *mgr.Manager) error { if conf.PublicHub() { initPageInput() } return nil } + +func (s *Ships) Stop(m *mgr.Manager) error { + return nil +} + +var ( + module *Ships + shimLoaded atomic.Bool +) + +// New returns a new Config module. 
+func New(instance instance) (*Ships, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Ships{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/sluice/module.go b/spn/sluice/module.go index 6ca15af1b..9c2c4cbf5 100644 --- a/spn/sluice/module.go +++ b/spn/sluice/module.go @@ -1,25 +1,36 @@ package sluice import ( + "errors" + "sync/atomic" + "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/spn/conf" ) -var ( - module *modules.Module +type SluiceModule struct { + mgr *mgr.Manager + instance instance +} + +func (s *SluiceModule) Start(m *mgr.Manager) error { + s.mgr = m + return start() +} +func (s *SluiceModule) Stop(_ *mgr.Manager) error { + return stop() +} + +var ( entrypointInfoMsg = []byte("You have reached the local SPN entry port, but your connection could not be matched to an SPN tunnel.\n") // EnableListener indicates if it should start the sluice listeners. Must be set at startup. EnableListener bool = true ) -func init() { - module = modules.Register("sluice", nil, start, stop, "terminal") -} - func start() error { // TODO: // Listening on all interfaces for now, as we need this for Windows. @@ -44,3 +55,22 @@ func stop() error { stopAllSluices() return nil } + +var ( + module *SluiceModule + shimLoaded atomic.Bool +) + +// New returns a new Config module. +func New(instance instance) (*SluiceModule, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &SluiceModule{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/terminal/module.go b/spn/terminal/module.go index 01cd6d066..3caeaf82a 100644 --- a/spn/terminal/module.go +++ b/spn/terminal/module.go @@ -1,17 +1,31 @@ package terminal import ( + "errors" "flag" + "sync/atomic" "time" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/unit" ) +type TerminalModule struct { + mgr *mgr.Manager + instance instance +} + +func (s *TerminalModule) Start(m *mgr.Manager) error { + return start() +} + +func (s *TerminalModule) Stop(m *mgr.Manager) error { + return nil +} + var ( - module *modules.Module rngFeeder *rng.Feeder = rng.NewFeeder() scheduler *unit.Scheduler @@ -21,8 +35,6 @@ var ( func init() { flag.BoolVar(&debugUnitScheduling, "debug-unit-scheduling", false, "enable debug logs of the SPN unit scheduler") - - module = modules.Register("terminal", nil, start, nil, "base") } func start() error { @@ -33,7 +45,7 @@ func start() error { // Debug unit leaks. scheduler.StartDebugLog() } - module.StartServiceWorker("msg unit scheduler", 0, scheduler.SlotScheduler) + module.mgr.Go("msg unit scheduler", scheduler.SlotScheduler) lockOpRegistry() @@ -78,3 +90,22 @@ func getSchedulerConfig() *unit.SchedulerConfig { StatCycleDuration: 1 * time.Minute, // Match metrics report cycle. } } + +var ( + module *TerminalModule + shimLoaded atomic.Bool +) + +// New returns a new Config module. 
+func New(instance instance) (*TerminalModule, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &TerminalModule{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/spn/terminal/operation.go b/spn/terminal/operation.go index 23249be05..2f58ce639 100644 --- a/spn/terminal/operation.go +++ b/spn/terminal/operation.go @@ -1,7 +1,6 @@ package terminal import ( - "context" "sync" "sync/atomic" "time" @@ -11,6 +10,7 @@ import ( "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/utils" + "github.com/safing/portmaster/service/mgr" ) // Operation is an interface for all operations. @@ -245,7 +245,7 @@ func (t *TerminalBase) StopOperation(op Operation, err *Error) { log.Warningf("spn/terminal: operation %s %s failed: %s", op.Type(), fmtOperationID(t.parentID, t.id, op.ID()), err) } - module.StartWorker("stop operation", func(_ context.Context) error { + module.mgr.Go("stop operation", func(_ *mgr.WorkerCtx) error { // Call operation stop handle function for proper shutdown cleaning up. err = op.HandleStop(err) diff --git a/spn/unit/scheduler.go b/spn/unit/scheduler.go index 0b5d6e113..1db0c5014 100644 --- a/spn/unit/scheduler.go +++ b/spn/unit/scheduler.go @@ -8,6 +8,7 @@ import ( "sync/atomic" "time" + "github.com/safing/portmaster/service/mgr" "github.com/tevino/abool" ) @@ -185,7 +186,7 @@ func (s *Scheduler) announceNextSlot() { // SlotScheduler manages the slot and schedules units. // Must only be started once. -func (s *Scheduler) SlotScheduler(ctx context.Context) error { +func (s *Scheduler) SlotScheduler(ctx *mgr.WorkerCtx) error { // Start slot ticker. ticker := time.NewTicker(s.config.SlotDuration / 2) defer ticker.Stop() From 4b4ff4a8d0ef1b0085d9d3a7fae0286930cce218 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Mon, 17 Jun 2024 17:58:56 +0300 Subject: [PATCH 09/56] [WIP] switch all service modules --- base/database/interface_cache.go | 6 +- service/broadcasts/module.go | 47 +++++++-- service/broadcasts/notify.go | 4 +- service/compat/module.go | 89 +++++++++++------ service/compat/notify.go | 3 +- service/instance.go | 108 ++++++++++++++++++-- service/mgr/worker.go | 2 +- service/nameserver/module.go | 64 +++++++++--- service/netenv/main.go | 2 +- service/netquery/module_api.go | 158 +++++++++++++++++------------- service/network/clean.go | 8 +- service/network/dns.go | 3 +- service/network/module.go | 83 ++++++++++------ service/process/module.go | 48 +++++++-- service/profile/merge.go | 4 +- service/profile/module.go | 4 +- service/resolver/failing.go | 5 +- service/resolver/main.go | 113 ++++++++++++--------- service/resolver/metrics.go | 3 +- service/resolver/resolver-mdns.go | 25 ++--- service/sync/module.go | 46 +++++++-- service/ui/module.go | 6 +- spn/access/module.go | 12 +-- spn/captain/api.go | 1 - spn/captain/module.go | 21 ++-- spn/crew/module.go | 4 +- spn/navigator/module.go | 23 ++--- 27 files changed, 605 insertions(+), 287 deletions(-) diff --git a/base/database/interface_cache.go b/base/database/interface_cache.go index 06b16d174..b860a2a31 100644 --- a/base/database/interface_cache.go +++ b/base/database/interface_cache.go @@ -1,16 +1,16 @@ package database import ( - "context" "errors" "time" "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" ) // DelayedCacheWriter must be run by the 
caller of an interface that uses delayed cache writing. -func (i *Interface) DelayedCacheWriter(ctx context.Context) error { +func (i *Interface) DelayedCacheWriter(wc *mgr.WorkerCtx) error { // Check if the DelayedCacheWriter should be run at all. if i.options.CacheSize <= 0 || i.options.DelayCachedWrites == "" { return errors.New("delayed cache writer is not applicable to this database interface") @@ -32,7 +32,7 @@ func (i *Interface) DelayedCacheWriter(ctx context.Context) error { for { // Wait for trigger for writing the cache. select { - case <-ctx.Done(): + case <-wc.Done(): // The caller is shutting down, flush the cache to storage and exit. i.flushWriteCache(0) return nil diff --git a/service/broadcasts/module.go b/service/broadcasts/module.go index 9741d7e14..1f6d3882b 100644 --- a/service/broadcasts/module.go +++ b/service/broadcasts/module.go @@ -1,16 +1,33 @@ package broadcasts import ( + "errors" "sync" + "sync/atomic" "time" "github.com/safing/portmaster/base/database" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) -var ( - module *modules.Module +type Broadcasts struct { + mgr *mgr.Manager + instance instance +} + +func (b *Broadcasts) Start(m *mgr.Manager) error { + b.mgr = m + if err := prep(); err != nil { + return err + } + return start() +} +func (b *Broadcasts) Stop(m *mgr.Manager) error { + return nil +} + +var ( db = database.NewInterface(&database.Options{ Local: true, Internal: true, @@ -20,7 +37,7 @@ var ( ) func init() { - module = modules.Register("broadcasts", prep, start, nil, "updates", "netenv", "notifications") + // module = modules.Register("broadcasts", prep, start, nil, "updates", "netenv", "notifications") } func prep() error { @@ -38,9 +55,27 @@ func start() error { // Start broadcast notifier task. startOnce.Do(func() { - module.NewTask("broadcast notifier", broadcastNotify). - Repeat(10 * time.Minute).Queue() + module.mgr.Repeat("broadcast notifier", 10*time.Minute, broadcastNotify) }) return nil } + +var ( + module *Broadcasts + shimLoaded atomic.Bool +) + +// New returns a new Broadcasts module. +func New(instance instance) (*Broadcasts, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Broadcasts{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/service/broadcasts/notify.go b/service/broadcasts/notify.go index 57a7830ad..c425a4677 100644 --- a/service/broadcasts/notify.go +++ b/service/broadcasts/notify.go @@ -16,8 +16,8 @@ import ( "github.com/safing/portmaster/base/database/accessor" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" ) @@ -66,7 +66,7 @@ type BroadcastNotification struct { repeatDuration time.Duration } -func broadcastNotify(ctx context.Context, t *modules.Task) error { +func broadcastNotify(ctx *mgr.WorkerCtx) error { // Get broadcast notifications file, load it from disk and parse it.
broadcastsResource, err := updates.GetFile(broadcastsResourcePath) if err != nil { diff --git a/service/compat/module.go b/service/compat/module.go index 0d9c3bd91..23a230e6a 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -1,22 +1,38 @@ package compat import ( - "context" "errors" + "sync/atomic" "time" "github.com/tevino/abool" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/resolver" ) -var ( - module *modules.Module +type Compat struct { + mgr *mgr.Manager + instance instance +} + +// Start starts the module. +func (u *Compat) Start(m *mgr.Manager) error { + u.mgr = m + if err := prep(); err != nil { + return err + } + return start() +} - selfcheckTask *modules.Task +// Stop stops the module. +func (u *Compat) Stop(_ *mgr.Manager) error { + return stop() +} + +var ( selfcheckTaskRetryAfter = 15 * time.Second // selfCheckIsFailing holds whether or not the self-check is currently @@ -38,7 +54,7 @@ var ( const selfcheckFailThreshold = 10 func init() { - module = modules.Register("compat", prep, start, stop, "base", "network", "interception", "netenv", "notifications") + // module = modules.Register("compat", prep, start, stop, "base", "network", "interception", "netenv", "notifications") // Workaround resolver integration. // See resolver/compat.go for details. @@ -55,35 +71,30 @@ func start() error { startNotify() selfcheckNetworkChangedFlag.Refresh() - selfcheckTask = module.NewTask("compatibility self-check", selfcheckTaskFunc). - Repeat(5 * time.Minute). - MaxDelay(selfcheckTaskRetryAfter). - Schedule(time.Now().Add(selfcheckTaskRetryAfter)) - - module.NewTask("clean notify thresholds", cleanNotifyThreshold). - Repeat(1 * time.Hour) - - return module.RegisterEventHook( - netenv.ModuleName, - netenv.NetworkChangedEvent, - "trigger compat self-check", - func(_ context.Context, _ interface{}) error { - selfcheckTask.Schedule(time.Now().Add(selfcheckTaskRetryAfter)) - return nil - }, - ) + module.mgr.Repeat("compatibility self-check", 5*time.Minute, selfcheckTaskFunc) + // selfcheckTask = module.NewTask("compatibility self-check", selfcheckTaskFunc). + // Repeat(5 * time.Minute). + // MaxDelay(selfcheckTaskRetryAfter). + // Schedule(time.Now().Add(selfcheckTaskRetryAfter)) + + module.mgr.Repeat("clean notify thresholds", 1*time.Hour, cleanNotifyThreshold) + module.instance.NetEnv().EventNetworkChange.AddCallback("trigger compat self-check", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { + module.mgr.Delay("trigger compat self-check", selfcheckTaskRetryAfter, selfcheckTaskFunc) + return false, nil + }) + return nil } func stop() error { - selfcheckTask.Cancel() - selfcheckTask = nil + // selfcheckTask.Cancel() + // selfcheckTask = nil return nil } -func selfcheckTaskFunc(ctx context.Context, task *modules.Task) error { +func selfcheckTaskFunc(wc *mgr.WorkerCtx) error { // Create tracing logger. - ctx, tracer := log.AddTracer(ctx) + ctx, tracer := log.AddTracer(wc.Ctx()) defer tracer.Submit() tracer.Tracef("compat: running self-check") @@ -115,7 +126,8 @@ func selfcheckTaskFunc(ctx context.Context, task *modules.Task) error { } // Retry quicker when failed. 
- task.Schedule(time.Now().Add(selfcheckTaskRetryAfter)) + module.mgr.Delay("trigger compat self-check", selfcheckTaskRetryAfter, selfcheckTaskFunc) + // task.Schedule(time.Now().Add(selfcheckTaskRetryAfter)) return nil } @@ -135,3 +147,24 @@ func selfcheckTaskFunc(ctx context.Context, task *modules.Task) error { func SelfCheckIsFailing() bool { return selfCheckIsFailing.IsSet() } + +var ( + module *Compat + shimLoaded atomic.Bool +) + +// New returns a new Compat module. +func New(instance instance) (*Compat, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Compat{ + instance: instance, + } + return module, nil +} + +type instance interface { + NetEnv() *netenv.NetEnv +} diff --git a/service/compat/notify.go b/service/compat/notify.go index 38db2168b..4f93d4195 100644 --- a/service/compat/notify.go +++ b/service/compat/notify.go @@ -12,6 +12,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/profile" ) @@ -273,7 +274,7 @@ func isOverThreshold(id string) bool { return false } -func cleanNotifyThreshold(ctx context.Context, task *modules.Task) error { +func cleanNotifyThreshold(ctx *mgr.WorkerCtx) error { notifyThresholdsLock.Lock() defer notifyThresholdsLock.Unlock() diff --git a/service/instance.go b/service/instance.go index 9a86401a5..5fd89a63b 100644 --- a/service/instance.go +++ b/service/instance.go @@ -9,11 +9,19 @@ import ( "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/base/runtime" + "github.com/safing/portmaster/service/broadcasts" + "github.com/safing/portmaster/service/compat" "github.com/safing/portmaster/service/firewall" "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/nameserver" "github.com/safing/portmaster/service/netenv" + "github.com/safing/portmaster/service/netquery" + "github.com/safing/portmaster/service/network" + "github.com/safing/portmaster/service/process" "github.com/safing/portmaster/service/profile" + "github.com/safing/portmaster/service/resolver" "github.com/safing/portmaster/service/status" + "github.com/safing/portmaster/service/sync" "github.com/safing/portmaster/service/ui" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/spn/access" @@ -52,12 +60,20 @@ type Instance struct { sluice *sluice.SluiceModule terminal *terminal.TerminalModule - updates *updates.Updates - ui *ui.UI - profile *profile.ProfileModule - filter *firewall.Filter - netenv *netenv.NetEnv - status *status.Status + updates *updates.Updates + ui *ui.UI + profile *profile.ProfileModule + filter *firewall.Filter + netenv *netenv.NetEnv + status *status.Status + broadcasts *broadcasts.Broadcasts + compat *compat.Compat + nameserver *nameserver.NameServer + netquery *netquery.NetQuery + network *network.Network + process *process.ProcessModule + resolver *resolver.ResolverModule + sync *sync.Sync } // New returns a new portmaster service instance. 
@@ -162,6 +178,38 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create status module: %w", err) } + instance.broadcasts, err = broadcasts.New(instance) + if err != nil { + return nil, fmt.Errorf("create broadcasts module: %w", err) + } + instance.compat, err = compat.New(instance) + if err != nil { + return nil, fmt.Errorf("create compat module: %w", err) + } + instance.nameserver, err = nameserver.New(instance) + if err != nil { + return nil, fmt.Errorf("create nameserver module: %w", err) + } + instance.netquery, err = netquery.NewModule(instance) + if err != nil { + return nil, fmt.Errorf("create netquery module: %w", err) + } + instance.network, err = network.New(instance) + if err != nil { + return nil, fmt.Errorf("create network module: %w", err) + } + instance.process, err = process.New(instance) + if err != nil { + return nil, fmt.Errorf("create process module: %w", err) + } + instance.resolver, err = resolver.New(instance) + if err != nil { + return nil, fmt.Errorf("create resolver module: %w", err) + } + instance.sync, err = sync.New(instance) + if err != nil { + return nil, fmt.Errorf("create sync module: %w", err) + } // Add all modules to instance group. instance.Group = mgr.NewGroup( @@ -189,6 +237,14 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.filter, instance.netenv, instance.status, + instance.broadcasts, + instance.compat, + instance.nameserver, + instance.netquery, + instance.network, + instance.process, + instance.resolver, + instance.sync, ) return instance, nil @@ -303,3 +359,43 @@ func (i *Instance) NetEnv() *netenv.NetEnv { func (i *Instance) Status() *status.Status { return i.status } + +// Broadcasts returns the broadcast module. +func (i *Instance) Broadcasts() *broadcasts.Broadcasts { + return i.broadcasts +} + +// Compat returns the compat module. +func (i *Instance) Compat() *compat.Compat { + return i.compat +} + +// NameServer returns the nameserver module. +func (i *Instance) NameServer() *nameserver.NameServer { + return i.nameserver +} + +// NetQuery returns the netquery module. +func (i *Instance) NetQuery() *netquery.NetQuery { + return i.netquery +} + +// Network returns the network module. +func (i *Instance) Network() *network.Network { + return i.network +} + +// Process returns the process module. +func (i *Instance) Process() *process.ProcessModule { + return i.process +} + +// Resolver returns the resolver module. +func (i *Instance) Resolver() *resolver.ResolverModule { + return i.resolver +} + +// Sync returns the sync module. +func (i *Instance) Sync() *sync.Sync { + return i.sync +} diff --git a/service/mgr/worker.go b/service/mgr/worker.go index 15464442f..95933660f 100644 --- a/service/mgr/worker.go +++ b/service/mgr/worker.go @@ -350,7 +350,7 @@ repeat: continue repeat case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): - // A canceled context or dexceeded eadline also means that the worker is finished. + // A canceled context or exceeded deadline also means that the worker is finished.
continue repeat default: diff --git a/service/nameserver/module.go b/service/nameserver/module.go index 6dcd320dd..c8d0afd08 100644 --- a/service/nameserver/module.go +++ b/service/nameserver/module.go @@ -1,27 +1,42 @@ package nameserver import ( - "context" + "errors" "fmt" "net" "os" "strconv" "sync" + "sync/atomic" "github.com/miekg/dns" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" - "github.com/safing/portmaster/base/modules/subsystems" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/compat" "github.com/safing/portmaster/service/firewall" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" ) -var ( - module *modules.Module +type NameServer struct { + mgr *mgr.Manager + instance instance +} +func (ns *NameServer) Start(m *mgr.Manager) error { + ns.mgr = m + if err := prep(); err != nil { + return err + } + return start() +} + +func (ns *NameServer) Stop(m *mgr.Manager) error { + return stop() +} + +var ( stopListeners bool stopListener1 func() error stopListener2 func() error @@ -32,15 +47,15 @@ var ( ) func init() { - module = modules.Register("nameserver", prep, start, stop, "core", "resolver") - subsystems.Register( - "dns", - "Secure DNS", - "DNS resolver with scoping and DNS-over-TLS", - module, - "config:dns/", - nil, - ) + // module = modules.Register("nameserver", prep, start, stop, "core", "resolver") + // subsystems.Register( + // "dns", + // "Secure DNS", + // "DNS resolver with scoping and DNS-over-TLS", + // module, + // "config:dns/", + // nil, + // ) } func prep() error { @@ -101,7 +116,7 @@ func start() error { func startListener(ip net.IP, port uint16, first bool) { // Start DNS server as service worker. - module.StartServiceWorker("dns resolver", 0, func(ctx context.Context) error { + module.mgr.Go("dns resolver", func(ctx *mgr.WorkerCtx) error { // Create DNS server. dnsServer := &dns.Server{ Addr: net.JoinHostPort( @@ -286,3 +301,22 @@ func getListenAddresses(listenAddress string) (ip1, ip2 net.IP, port uint16, err return ip1, ip2, uint16(port64), nil } + +var ( + module *NameServer + shimLoaded atomic.Bool +) + +// New returns a new NameServer module. +func New(instance instance) (*NameServer, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &NameServer{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/service/netenv/main.go b/service/netenv/main.go index 8f29df8b5..35b231278 100644 --- a/service/netenv/main.go +++ b/service/netenv/main.go @@ -84,7 +84,7 @@ var ( shimLoaded atomic.Bool ) -// New returns a new UI module. +// New returns a new NetEnv module. 
func New(instance instance) (*NetEnv, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go index 5d58709b8..2a4a2d008 100644 --- a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -3,7 +3,9 @@ package netquery import ( "context" "encoding/json" + "errors" "fmt" + "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -14,17 +16,15 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" - "github.com/safing/portmaster/base/modules/subsystems" "github.com/safing/portmaster/base/runtime" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network" + "github.com/safing/portmaster/service/profile" ) -// DefaultModule is the default netquery module. -var DefaultModule *module - -type module struct { - *modules.Module +type NetQuery struct { + mgr *mgr.Manager + instance instance Store *Database @@ -33,66 +33,67 @@ type module struct { feed chan *network.Connection } +// DefaultModule is the default netquery module. func init() { - DefaultModule = new(module) - - DefaultModule.Module = modules.Register( - "netquery", - DefaultModule.prepare, - DefaultModule.start, - DefaultModule.stop, - "api", - "network", - "database", - ) - - subsystems.Register( - "history", - "Network History", - "Keep Network History Data", - DefaultModule.Module, - "config:history/", - nil, - ) + // DefaultModule = new(module) + + // DefaultModule.Module = modules.Register( + // "netquery", + // DefaultModule.prepare, + // DefaultModule.start, + // DefaultModule.stop, + // "api", + // "network", + // "database", + // ) + + // subsystems.Register( + // "history", + // "Network History", + // "Keep Network History Data", + // DefaultModule.Module, + // "config:history/", + // nil, + // ) } -func (m *module) prepare() error { +func (nq *NetQuery) prepare() error { var err error - m.db = database.NewInterface(&database.Options{ + nq.db = database.NewInterface(&database.Options{ Local: true, Internal: true, }) // TODO: Open database in start() phase. 
- m.Store, err = NewInMemory() + nq.Store, err = NewInMemory() if err != nil { return fmt.Errorf("failed to create in-memory database: %w", err) } - m.mng, err = NewManager(m.Store, "netquery/data/", runtime.DefaultRegistry) + nq.mng, err = NewManager(nq.Store, "netquery/data/", runtime.DefaultRegistry) if err != nil { return fmt.Errorf("failed to create manager: %w", err) } - m.feed = make(chan *network.Connection, 1000) + nq.feed = make(chan *network.Connection, 1000) queryHander := &QueryHandler{ - Database: m.Store, + Database: nq.Store, IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), } batchHander := &BatchQueryHandler{ - Database: m.Store, + Database: nq.Store, IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), } chartHandler := &ActiveChartHandler{ - Database: m.Store, + Database: nq.Store, } bwChartHandler := &BandwidthChartHandler{ - Database: m.Store, + Database: nq.Store, } if err := api.RegisterEndpoint(api.Endpoint{ @@ -162,13 +163,13 @@ func (m *module) prepare() error { } if len(body.ProfileIDs) == 0 { - if err := m.mng.store.RemoveAllHistoryData(ar.Context()); err != nil { + if err := nq.mng.store.RemoveAllHistoryData(ar.Context()); err != nil { return "", fmt.Errorf("failed to remove all history: %w", err) } } else { merr := new(multierror.Error) for _, profileID := range body.ProfileIDs { - if err := m.mng.store.RemoveHistoryForProfile(ar.Context(), profileID); err != nil { + if err := nq.mng.store.RemoveHistoryForProfile(ar.Context(), profileID); err != nil { merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err)) } else { log.Infof("netquery: successfully cleared history for %s", profileID) @@ -192,7 +193,7 @@ func (m *module) prepare() error { Write: api.PermitUser, BelongsTo: m.Module, ActionFunc: func(ar *api.Request) (msg string, err error) { - if err := m.Store.CleanupHistory(ar.Context()); err != nil { + if err := nq.Store.CleanupHistory(ar.Context()); err != nil { return "", err } return "Deleted outdated connections.", nil @@ -204,13 +205,18 @@ func (m *module) prepare() error { return nil } -func (m *module) start() error { - m.StartServiceWorker("netquery connection feed listener", 0, func(ctx context.Context) error { - sub, err := m.db.Subscribe(query.New("network:")) +func (nq *NetQuery) Start(m *mgr.Manager) error { + nq.mgr = m + if err := nq.prepare(); err != nil { + return fmt.Errorf("failed to prepare netquery module: %w", err) + } + + nq.mgr.Go("netquery connection feed listener", func(ctx *mgr.WorkerCtx) error { + sub, err := nq.db.Subscribe(query.New("network:")) if err != nil { return fmt.Errorf("failed to subscribe to network tree: %w", err) } - defer close(m.feed) + defer close(nq.feed) defer func() { _ = sub.Cancel() }() @@ -231,24 +237,24 @@ func (m *module) start() error { continue } - m.feed <- conn + nq.feed <- conn } } }) - m.StartServiceWorker("netquery connection feed handler", 0, func(ctx context.Context) error { - m.mng.HandleFeed(ctx, m.feed) + nq.mgr.Go("netquery connection feed handler", func(ctx *mgr.WorkerCtx) error { + nq.mng.HandleFeed(ctx.Ctx(), nq.feed) return nil }) - m.StartServiceWorker("netquery live db cleaner", 0, func(ctx context.Context) error { + nq.mgr.Go("netquery live db cleaner", func(ctx *mgr.WorkerCtx) error { for { select { case <-ctx.Done(): return nil case <-time.After(10 * time.Second): threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold) - count, err := m.Store.Cleanup(ctx, threshold) + count, err := 
nq.Store.Cleanup(ctx.Ctx(), threshold) if err != nil { log.Errorf("netquery: failed to removed old connections from live db: %s", err) } else { @@ -258,51 +264,48 @@ func (m *module) start() error { } }) - m.NewTask("network history cleaner", func(ctx context.Context, _ *modules.Task) error { - return m.Store.CleanupHistory(ctx) - }).Repeat(time.Hour).Schedule(time.Now().Add(10 * time.Minute)) + nq.mgr.Delay("network history cleaner delay", 10*time.Minute, func(_ *mgr.WorkerCtx) error { + nq.mgr.Repeat("network history cleaner delay", 1*time.Hour, func(w *mgr.WorkerCtx) error { + return nq.Store.CleanupHistory(w.Ctx()) + }) + return nil + }) // For debugging, provide a simple direct SQL query interface using // the runtime database. // Only expose in development mode. if config.GetAsBool(config.CfgDevModeKey, false)() { - _, err := NewRuntimeQueryRunner(m.Store, "netquery/query/", runtime.DefaultRegistry) + _, err := NewRuntimeQueryRunner(nq.Store, "netquery/query/", runtime.DefaultRegistry) if err != nil { return fmt.Errorf("failed to set up runtime SQL query runner: %w", err) } } // Migrate profile IDs in history database when profiles are migrated/merged. - if err := m.RegisterEventHook( - "profiles", - "profile migrated", - "migrate profile IDs in history database", - func(ctx context.Context, data interface{}) error { - if profileIDs, ok := data.([]string); ok && len(profileIDs) == 2 { - return m.Store.MigrateProfileID(ctx, profileIDs[0], profileIDs[1]) + nq.instance.Profile().EventMigrated.AddCallback("migrate profile IDs in history database", + func(ctx *mgr.WorkerCtx, profileIDs []string) (bool, error) { + if len(profileIDs) == 2 { + return false, nq.Store.MigrateProfileID(ctx.Ctx(), profileIDs[0], profileIDs[1]) } - return nil - }, - ); err != nil { - return err - } + return false, nil + }) return nil } -func (m *module) stop() error { +func (nq *NetQuery) Stop(m *mgr.Manager) error { // we don't use m.Module.Ctx here because it is already cancelled when stop is called. // just give the clean up 1 minute to happen and abort otherwise. ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil { + if err := nq.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil { // handle the error by just logging it. There's not much we can do here // and returning an error to the module system doesn't help much as well... log.Errorf("netquery: failed to mark connections in history database as ended: %s", err) } - if err := m.mng.store.Close(); err != nil { + if err := nq.mng.store.Close(); err != nil { log.Errorf("netquery: failed to close sqlite database: %s", err) } else { // Clear deleted connections from database. @@ -313,3 +316,24 @@ func (m *module) stop() error { return nil } + +var ( + module *NetQuery + shimLoaded atomic.Bool +) + +// New returns a new NetQuery module. 
+func NewModule(instance instance) (*NetQuery, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &NetQuery{ + instance: instance, + } + return module, nil +} + +type instance interface { + Profile() *profile.ProfileModule +} diff --git a/service/network/clean.go b/service/network/clean.go index 3b04990b9..62fe04587 100644 --- a/service/network/clean.go +++ b/service/network/clean.go @@ -1,10 +1,10 @@ package network import ( - "context" "time" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/network/state" "github.com/safing/portmaster/service/process" @@ -31,7 +31,7 @@ const ( cleanerTickDuration = 5 * time.Second ) -func connectionCleaner(ctx context.Context) error { +func connectionCleaner(ctx *mgr.WorkerCtx) error { ticker := module.NewSleepyTicker(cleanerTickDuration, 0) for { @@ -45,7 +45,7 @@ func connectionCleaner(ctx context.Context) error { process.CleanProcessStorage(activePIDs) // clean udp connection states - state.CleanUDPStates(ctx) + state.CleanUDPStates(ctx.Ctx()) } } } @@ -53,7 +53,7 @@ func connectionCleaner(ctx context.Context) error { func cleanConnections() (activePIDs map[int]struct{}) { activePIDs = make(map[int]struct{}) - _ = module.RunMicroTask("clean connections", 0, func(ctx context.Context) error { + module.mgr.Go("clean connections", func(ctx *mgr.WorkerCtx) error { now := time.Now().UTC() nowUnix := now.Unix() ignoreNewer := nowUnix - 2 diff --git a/service/network/dns.go b/service/network/dns.go index 9b0dbe94e..c7e02b586 100644 --- a/service/network/dns.go +++ b/service/network/dns.go @@ -11,6 +11,7 @@ import ( "golang.org/x/exp/slices" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/nameserver/nsutil" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/process" @@ -173,7 +174,7 @@ func SaveOpenDNSRequest(q *resolver.Query, rrCache *resolver.RRCache, conn *Conn openDNSRequests[key] = conn } -func openDNSRequestWriter(ctx context.Context) error { +func openDNSRequestWriter(ctx *mgr.WorkerCtx) error { ticker := module.NewSleepyTicker(writeOpenDNSRequestsTickDuration, 0) defer ticker.Stop() diff --git a/service/network/module.go b/service/network/module.go index 2863cf99f..a14e3a7e9 100644 --- a/service/network/module.go +++ b/service/network/module.go @@ -2,33 +2,47 @@ package network import ( "context" + "errors" "fmt" "strings" "sync" + "sync/atomic" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/state" "github.com/safing/portmaster/service/profile" ) -var ( - module *modules.Module - - defaultFirewallHandler FirewallHandler -) - // Events. 
-var ( +const ( ConnectionReattributedEvent = "connection re-attributed" ) -func init() { - module = modules.Register("network", prep, start, nil, "base", "netenv", "processes") - module.RegisterEvent(ConnectionReattributedEvent, false) +type Network struct { + mgr *mgr.Manager + instance instance + + EventConnectionReattributed *mgr.EventMgr[string] +} + +func (n *Network) Start(m *mgr.Manager) error { + n.mgr = m + n.EventConnectionReattributed = mgr.NewEventMgr[string](ConnectionReattributedEvent, m) + + if err := prep(); err != nil { + return err + } + return start() +} + +func (n *Network) Stop(mgr *mgr.Manager) error { + return nil } +var defaultFirewallHandler FirewallHandler + // SetDefaultFirewallHandler sets the default firewall handler. func SetDefaultFirewallHandler(handler FirewallHandler) { if defaultFirewallHandler == nil { @@ -55,17 +69,9 @@ func start() error { return err } - module.StartServiceWorker("clean connections", 0, connectionCleaner) - module.StartServiceWorker("write open dns requests", 0, openDNSRequestWriter) - - if err := module.RegisterEventHook( - "profiles", - profile.DeletedEvent, - "re-attribute connections from deleted profile", - reAttributeConnections, - ); err != nil { - return err - } + module.mgr.Go("clean connections", connectionCleaner) + module.mgr.Go("write open dns requests", openDNSRequestWriter) + module.instance.Profile().EventDelete.AddCallback("re-attribute connections from deleted profile", reAttributeConnections) return nil } @@ -74,14 +80,10 @@ var reAttributionLock sync.Mutex // reAttributeConnections finds all connections of a deleted profile and re-attributes them. // Expected event data: scoped profile ID. -func reAttributeConnections(_ context.Context, eventData any) error { - profileID, ok := eventData.(string) - if !ok { - return fmt.Errorf("event data is not a string: %v", eventData) - } +func reAttributeConnections(_ *mgr.WorkerCtx, profileID string) (bool, error) { profileSource, profileID, ok := strings.Cut(profileID, "/") if !ok { - return fmt.Errorf("event data does not seem to be a scoped profile ID: %v", eventData) + return false, fmt.Errorf("event data does not seem to be a scoped profile ID: %v", profileID) } // Hold a lock for re-attribution, to prevent simultaneous processing of the @@ -114,7 +116,7 @@ func reAttributeConnections(_ context.Context, eventData any) error { } tracer.Infof("filter: re-attributed %d connections", reAttributed) - return nil + return false, nil } func reAttributeConnection(ctx context.Context, conn *Connection, profileID, profileSource string) (reAttributed bool) { @@ -144,8 +146,29 @@ func reAttributeConnection(ctx context.Context, conn *Connection, profileID, pro conn.Save() // Trigger event for re-attribution. - module.TriggerEvent(ConnectionReattributedEvent, conn.ID) + module.EventConnectionReattributed.Submit(conn.ID) log.Tracer(ctx).Debugf("filter: re-attributed %s to %s", conn, conn.process.PrimaryProfileID) return true } + +var ( + module *Network + shimLoaded atomic.Bool +) + +// New returns a new Network module. 
+func New(instance instance) (*Network, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Network{ + instance: instance, + } + return module, nil +} + +type instance interface { + Profile() *profile.ProfileModule +} diff --git a/service/process/module.go b/service/process/module.go index be97b26ea..ddc2808e1 100644 --- a/service/process/module.go +++ b/service/process/module.go @@ -1,21 +1,32 @@ package process import ( + "errors" "os" + "sync/atomic" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" ) -var ( - module *modules.Module - updatesPath string -) +type ProcessModule struct { + instance instance +} + +func (pm *ProcessModule) Start(m *mgr.Manager) error { + if err := prep(); err != nil { + return err + } -func init() { - module = modules.Register("processes", prep, start, nil, "profiles", "updates") + return start() } +func (pm *ProcessModule) Stop(m *mgr.Manager) error { + return nil +} + +var updatesPath string + func prep() error { return registerConfiguration() } @@ -32,3 +43,26 @@ func start() error { return nil } + +var ( + module *ProcessModule + shimLoaded atomic.Bool +) + +// New returns a new Process module. +func New(instance instance) (*ProcessModule, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + if err := prep(); err != nil { + return nil, err + } + + module = &ProcessModule{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/service/profile/merge.go b/service/profile/merge.go index cc3755dbf..d2c3f1c29 100644 --- a/service/profile/merge.go +++ b/service/profile/merge.go @@ -77,12 +77,12 @@ func MergeProfiles(name string, primary *Profile, secondaries ...*Profile) (newP if err := primary.delete(); err != nil { return nil, fmt.Errorf("failed to delete primary profile %s: %w", primary.ScopedID(), err) } - module.TriggerEvent(MigratedEvent, []string{primary.ScopedID(), newProfile.ScopedID()}) + module.EventMigrated.Submit([]string{primary.ScopedID(), newProfile.ScopedID()}) for _, sp := range secondaries { if err := sp.delete(); err != nil { return nil, fmt.Errorf("failed to delete secondary profile %s: %w", sp.ScopedID(), err) } - module.TriggerEvent(MigratedEvent, []string{sp.ScopedID(), newProfile.ScopedID()}) + module.EventMigrated.Submit([]string{sp.ScopedID(), newProfile.ScopedID()}) } return newProfile, nil diff --git a/service/profile/module.go b/service/profile/module.go index 03532fad2..162b63cec 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -36,7 +36,7 @@ type ProfileModule struct { EventConfigChange *mgr.EventMgr[string] EventDelete *mgr.EventMgr[string] - EventMigrated *mgr.EventMgr[string] + EventMigrated *mgr.EventMgr[[]string] } func (pm *ProfileModule) Start(m *mgr.Manager) error { @@ -44,7 +44,7 @@ func (pm *ProfileModule) Start(m *mgr.Manager) error { pm.EventConfigChange = mgr.NewEventMgr[string](ConfigChangeEvent, m) pm.EventDelete = mgr.NewEventMgr[string](DeletedEvent, m) - pm.EventMigrated = mgr.NewEventMgr[string](MigratedEvent, m) + pm.EventMigrated = mgr.NewEventMgr[[]string](MigratedEvent, m) if err := prep(); err != nil { return err diff --git a/service/resolver/failing.go b/service/resolver/failing.go index c1e347b5d..c8e011bd0 100644 --- a/service/resolver/failing.go +++ b/service/resolver/failing.go @@ -6,6 +6,7 @@ import ( 
"github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" ) @@ -71,7 +72,7 @@ func (brc *BasicResolverConn) ResetFailure() { } } -func checkFailingResolvers(ctx context.Context, task *modules.Task) error { +func checkFailingResolvers(wc *mgr.WorkerCtx) error { //, task *modules.Task) error { var resolvers []*Resolver // Make a copy of the resolver list. @@ -84,7 +85,7 @@ func checkFailingResolvers(ctx context.Context, task *modules.Task) error { }() // Start logging. - ctx, tracer := log.AddTracer(ctx) + ctx, tracer := log.AddTracer(wc.Ctx()) tracer.Debugf("resolver: checking failed resolvers") defer tracer.Submit() diff --git a/service/resolver/main.go b/service/resolver/main.go index f50dcc226..6d50dcc89 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -2,27 +2,43 @@ package resolver import ( "context" + "errors" "fmt" "net" "strings" "sync" + "sync/atomic" "time" "github.com/tevino/abool" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/base/utils/debug" _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/intel" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" + "github.com/safing/portmaster/spn/captain" ) -var module *modules.Module +type ResolverModule struct { + mgr *mgr.Manager + instance instance +} -func init() { - module = modules.Register("resolver", prep, start, nil, "base", "netenv") +func (rm *ResolverModule) Start(m *mgr.Manager) error { + rm.mgr = m + if err := prep(); err != nil { + return err + } + return start() +} + +func (rm *ResolverModule) Stop(m *mgr.Manager) error { + return nil } func prep() error { @@ -49,41 +65,28 @@ func start() error { loadResolvers() // reload after network change - err := module.RegisterEventHook( - "netenv", - "network changed", + module.instance.NetEnv().EventNetworkChange.AddCallback( "update nameservers", - func(_ context.Context, _ interface{}) error { + func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { loadResolvers() log.Debug("resolver: reloaded nameservers due to network change") - return nil + return false, nil }, ) - if err != nil { - return err - } // Force resolvers to reconnect when SPN has connected. - if err := module.RegisterEventHook( - "captain", - "spn connect", // Defined by captain.SPNConnectedEvent + module.instance.Captain().EventSPNConnected.AddCallback( "force resolver reconnect", - func(ctx context.Context, _ any) error { - ForceResolverReconnect(ctx) - return nil - }, - ); err != nil { - // This module does not depend on the SPN/Captain module, and probably should not. 
- log.Warningf("resolvers: failed to register event hook for captain/spn-connect: %s", err) - } + func(ctx *mgr.WorkerCtx, _ struct{}) (bool, error) { + ForceResolverReconnect(ctx.Ctx()) + return false, nil + }) // reload after config change prevNameservers := strings.Join(configuredNameServers(), " ") - err = module.RegisterEventHook( - "config", - "config change", + module.instance.Config().EventConfigChange.AddCallback( "update nameservers", - func(_ context.Context, _ interface{}) error { + func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { newNameservers := strings.Join(configuredNameServers(), " ") if newNameservers != prevNameservers { prevNameservers = newNameservers @@ -91,38 +94,27 @@ func start() error { loadResolvers() log.Debug("resolver: reloaded nameservers due to config change") } - return nil - }, - ) - if err != nil { - return err - } + return false, nil + }) // Check failing resolvers regularly and when the network changes. - checkFailingResolversTask := module.NewTask("check failing resolvers", checkFailingResolvers).Repeat(1 * time.Minute) - err = module.RegisterEventHook( - "netenv", - netenv.NetworkChangedEvent, + module.mgr.Repeat("check failing resolvers", 1*time.Minute, checkFailingResolvers) + module.instance.NetEnv().EventNetworkChange.AddCallback( "check failing resolvers", - func(_ context.Context, _ any) error { - checkFailingResolversTask.StartASAP() - return nil - }, - ) - if err != nil { - return err - } + func(wc *mgr.WorkerCtx, _ struct{}) (bool, error) { + checkFailingResolvers(wc) + return false, nil + }) - module.NewTask("suggest using stale cache", suggestUsingStaleCacheTask).Repeat(2 * time.Minute) + module.mgr.Repeat("suggest using stale cache", 2*time.Minute, suggestUsingStaleCacheTask) - module.StartServiceWorker( + module.mgr.Go( "mdns handler", - 5*time.Second, listenToMDNS, ) - module.StartServiceWorker("name record delayed cache writer", 0, recordDatabase.DelayedCacheWriter) - module.StartServiceWorker("ip info delayed cache writer", 0, ipInfoDatabase.DelayedCacheWriter) + module.mgr.Go("name record delayed cache writer", recordDatabase.DelayedCacheWriter) + module.mgr.Go("ip info delayed cache writer", ipInfoDatabase.DelayedCacheWriter) return nil } @@ -247,3 +239,26 @@ func AddToDebugInfo(di *debug.Info) { content..., ) } + +var ( + module *ResolverModule + shimLoaded atomic.Bool +) + +// New returns a new Resolver module. 
+func New(instance instance) (*ResolverModule, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &ResolverModule{ + instance: instance, + } + return module, nil +} + +type instance interface { + NetEnv() *netenv.NetEnv + Captain() *captain.Captain + Config() *config.Config +} diff --git a/service/resolver/metrics.go b/service/resolver/metrics.go index 02ce9897b..c2589ee8f 100644 --- a/service/resolver/metrics.go +++ b/service/resolver/metrics.go @@ -8,6 +8,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" ) var ( @@ -49,7 +50,7 @@ func resetSlowQueriesSensorValue() { var suggestUsingStaleCacheNotification *notifications.Notification -func suggestUsingStaleCacheTask(ctx context.Context, t *modules.Task) error { +func suggestUsingStaleCacheTask(_ *mgr.WorkerCtx) error { // t *modules.Task) error { switch { case useStaleCache() || useStaleCacheConfigOption.IsSetByUser(): // If setting is already active, disable task repeating. diff --git a/service/resolver/resolver-mdns.go b/service/resolver/resolver-mdns.go index cb4ba62c4..0eb674f98 100644 --- a/service/resolver/resolver-mdns.go +++ b/service/resolver/resolver-mdns.go @@ -12,6 +12,7 @@ import ( "github.com/miekg/dns" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" ) @@ -74,7 +75,7 @@ func indexOfRR(entry *dns.RR_Header, list *[]dns.RR) int { } //nolint:gocyclo,gocognit // TODO: make simpler -func listenToMDNS(ctx context.Context) error { +func listenToMDNS(wc *mgr.WorkerCtx) error { var err error messages := make(chan *dns.Msg, 32) @@ -86,8 +87,8 @@ func listenToMDNS(ctx context.Context) error { // TODO: retry after some time log.Warningf("intel(mdns): failed to create udp4 listen multicast socket: %s", err) } else { - module.StartServiceWorker("mdns udp4 multicast listener", 0, func(ctx context.Context) error { - return listenForDNSPackets(ctx, multicast4Conn, messages) + module.mgr.Go("mdns udp4 multicast listener", func(wc *mgr.WorkerCtx) error { + return listenForDNSPackets(wc.Ctx(), multicast4Conn, messages) }) defer func() { _ = multicast4Conn.Close() @@ -99,8 +100,8 @@ func listenToMDNS(ctx context.Context) error { // TODO: retry after some time log.Warningf("intel(mdns): failed to create udp4 listen socket: %s", err) } else { - module.StartServiceWorker("mdns udp4 unicast listener", 0, func(ctx context.Context) error { - return listenForDNSPackets(ctx, unicast4Conn, messages) + module.mgr.Go("mdns udp4 unicast listener", func(wc *mgr.WorkerCtx) error { + return listenForDNSPackets(wc.Ctx(), unicast4Conn, messages) }) defer func() { _ = unicast4Conn.Close() @@ -113,8 +114,8 @@ func listenToMDNS(ctx context.Context) error { // TODO: retry after some time log.Warningf("intel(mdns): failed to create udp6 listen multicast socket: %s", err) } else { - module.StartServiceWorker("mdns udp6 multicast listener", 0, func(ctx context.Context) error { - return listenForDNSPackets(ctx, multicast6Conn, messages) + module.mgr.Go("mdns udp6 multicast listener", func(wc *mgr.WorkerCtx) error { + return listenForDNSPackets(wc.Ctx(), multicast6Conn, messages) }) defer func() { _ = multicast6Conn.Close() @@ -126,8 +127,8 @@ func listenToMDNS(ctx context.Context) error { // TODO: retry after some time 
log.Warningf("intel(mdns): failed to create udp6 listen socket: %s", err) } else { - module.StartServiceWorker("mdns udp6 unicast listener", 0, func(ctx context.Context) error { - return listenForDNSPackets(ctx, unicast6Conn, messages) + module.mgr.Go("mdns udp6 unicast listener", func(wc *mgr.WorkerCtx) error { + return listenForDNSPackets(wc.Ctx(), unicast6Conn, messages) }) defer func() { _ = unicast6Conn.Close() @@ -138,12 +139,12 @@ func listenToMDNS(ctx context.Context) error { } // start message handler - module.StartServiceWorker("mdns message handler", 0, func(ctx context.Context) error { - return handleMDNSMessages(ctx, messages) + module.mgr.Go("mdns message handler", func(wc *mgr.WorkerCtx) error { + return handleMDNSMessages(wc.Ctx(), messages) }) // wait for shutdown - <-module.Ctx.Done() + <-wc.Done() return nil } diff --git a/service/sync/module.go b/service/sync/module.go index e6d43142f..3d7a8a8ed 100644 --- a/service/sync/module.go +++ b/service/sync/module.go @@ -1,23 +1,30 @@ package sync import ( + "errors" + "sync/atomic" + "github.com/safing/portmaster/base/database" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) -var ( - module *modules.Module +type Sync struct { + instance instance +} - db = database.NewInterface(&database.Options{ - Local: true, - Internal: true, - }) -) +func (s *Sync) Start(m *mgr.Manager) error { + return prep() +} -func init() { - module = modules.Register("sync", prep, nil, nil, "profiles") +func (s *Sync) Stop(m *mgr.Manager) error { + return nil } +var db = database.NewInterface(&database.Options{ + Local: true, + Internal: true, +}) + func prep() error { if err := registerSettingsAPI(); err != nil { return err @@ -30,3 +37,22 @@ func prep() error { } return nil } + +var ( + module *Sync + shimLoaded atomic.Bool +) + +// New returns a new NetEnv module. +func New(instance instance) (*Sync, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Sync{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/service/ui/module.go b/service/ui/module.go index 7a1691258..54ef93838 100644 --- a/service/ui/module.go +++ b/service/ui/module.go @@ -50,12 +50,10 @@ func (ui *UI) Start(m *mgr.Manager) error { // Stop stops the module. func (ui *UI) Stop(_ *mgr.Manager) error { - return stop() + return nil } -var ( - shimLoaded atomic.Bool -) +var shimLoaded atomic.Bool // New returns a new UI module. func New(instance instance) (*UI, error) { diff --git a/spn/access/module.go b/spn/access/module.go index 3b6333a26..2ff60cdba 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -83,10 +83,10 @@ func start() error { loadTokens() // Register new task. - accountUpdateTask = module.mgr.Go( - "update account", - UpdateAccount, - ).Repeat(24 * time.Hour).Schedule(time.Now().Add(1 * time.Minute)) + module.mgr.Delay("update account delayed", 1*time.Minute, func(_ *mgr.WorkerCtx) error { + module.mgr.Repeat("update account", 24*time.Hour, UpdateAccount) + return nil + }) } return nil @@ -95,8 +95,8 @@ func start() error { func stop() error { if conf.Client() { // Stop account update task. - accountUpdateTask.Cancel() - accountUpdateTask = nil + // accountUpdateTask.Cancel() + // accountUpdateTask = nil // Store tokens to database. 
storeTokens() diff --git a/spn/captain/api.go index ec43bc44e..0b37828b7 100644 --- a/spn/captain/api.go +++ b/spn/captain/api.go @@ -7,7 +7,6 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/query" - "github.com/safing/portmaster/base/modules" ) const ( diff --git a/spn/captain/module.go b/spn/captain/module.go index 55299f292..20876a22b 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -1,7 +1,6 @@ package captain import ( - "context" "errors" "fmt" "net" @@ -95,17 +94,13 @@ func prep() error { return err } - if err := module.RegisterEventHook( - "patrol", - patrol.ChangeSignalEventName, + module.instance.Patrol().EventChangeSignal.AddCallback( "trigger hub status maintenance", - func(_ context.Context, _ any) error { + func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { TriggerHubStatusMaintenance() - return nil + return false, nil }, - ); err != nil { - return err - } + ) } return prepConfig() @@ -172,9 +167,10 @@ func start() error { // network optimizer if conf.PublicHub() { - module.mgr.Go("optimize network", optimizeNetwork). - Repeat(1 * time.Minute). - Schedule(time.Now().Add(15 * time.Second)) + module.mgr.Delay("optimize network delay", 15*time.Second, func(_ *mgr.WorkerCtx) error { + module.mgr.Repeat("optimize network", 1*time.Minute, optimizeNetwork) + return nil + }) } // client + home hub manager @@ -250,4 +246,5 @@ func New(instance instance) (*Captain, error) { type instance interface { NetEnv() *netenv.NetEnv + Patrol() *patrol.Patrol } diff --git a/spn/crew/module.go b/spn/crew/module.go index 2a677435d..f96cee59c 100644 --- a/spn/crew/module.go +++ b/spn/crew/module.go @@ -24,9 +24,7 @@ func (c *Crew) Stop(m *mgr.Manager) error { } func start() error { - module.NewTask("sticky cleaner", cleanStickyHubs). - Repeat(10 * time.Minute) - + module.mgr.Repeat("sticky cleaner", 10*time.Minute, cleanStickyHubs) return registerMetrics() } diff --git a/spn/navigator/module.go b/spn/navigator/module.go index c8027fb90..774dc670f 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -110,20 +110,21 @@ geoInitCheck: } // TODO: delete superseded hubs after x amount of time - - module.NewTask("update states", Main.updateStates). - Repeat(1 * time.Hour). - Schedule(time.Now().Add(3 * time.Minute)) - - module.NewTask("update failing states", Main.updateFailingStates). - Repeat(1 * time.Minute). - Schedule(time.Now().Add(3 * time.Minute)) + module.mgr.Delay("update states delay", 3*time.Minute, func(w *mgr.WorkerCtx) error { + module.mgr.Repeat("update states", 1*time.Hour, Main.updateStates) + return nil + }) + module.mgr.Delay("update failing states delay", 3*time.Minute, func(w *mgr.WorkerCtx) error { + module.mgr.Repeat("update failing states", 1*time.Minute, Main.updateFailingStates) + return nil + }) if conf.PublicHub() { // Only measure Hubs on public Hubs. - module.NewTask("measure hubs", Main.measureHubs). - Repeat(5 * time.Minute). - Schedule(time.Now().Add(1 * time.Minute)) + module.mgr.Delay("measure hubs delay", 1*time.Minute, func(w *mgr.WorkerCtx) error { + module.mgr.Repeat("measure hubs", 5*time.Minute, Main.measureHubs) + return nil + }) // Only register metrics on Hubs, as they only make sense there.
err := registerMetrics() From 49e98fef981ee84056c7a18e498ce3cf26151b9b Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Thu, 20 Jun 2024 14:41:56 +0300 Subject: [PATCH 10/56] [WIP] Convert all workers to the new module system --- base/api/authentication.go | 8 +- base/api/database.go | 10 +- base/api/endpoints.go | 11 - base/api/endpoints_debug.go | 15 +- base/api/endpoints_modules.go | 93 ++++---- base/api/main.go | 12 +- base/api/main_test.go | 85 ++++--- base/api/module.go | 16 +- base/api/modules.go | 64 +++-- base/api/router.go | 27 ++- base/database/dbmodule/maintenance.go | 12 +- base/metrics/api.go | 8 +- base/metrics/module.go | 17 +- base/notifications/module-mirror.go | 218 +++++++++--------- base/notifications/notification.go | 16 +- base/rng/test/main.go | 2 - base/runtime/modules_integration.go | 8 +- base/template/module_test.go | 1 - base/utils/debug/debug.go | 24 +- cmds/hub/main.go | 2 - cmds/notifier/main.go | 1 - cmds/observation-hub/apprise.go | 1 - cmds/observation-hub/main.go | 2 - cmds/observation-hub/observe.go | 1 - cmds/portmaster-core/main.go | 29 ++- go.mod | 38 +-- go.sum | 40 ++++ service/broadcasts/api.go | 3 - service/broadcasts/notify.go | 9 +- service/compat/api.go | 1 - service/compat/notify.go | 9 +- service/compat/selfcheck.go | 3 +- service/core/api.go | 7 +- service/core/base/global.go | 9 +- service/core/base/logs.go | 12 +- service/core/base/module.go | 50 ++-- service/core/base/profiling.go | 7 +- service/core/core.go | 84 ++++--- service/core/pmtesting/testing.go | 1 - .../interception/interception_linux.go | 10 +- service/firewall/interception/module.go | 38 ++- .../firewall/interception/nfqueue_linux.go | 4 +- service/firewall/module.go | 60 ++--- service/firewall/packet_handler.go | 5 +- service/instance.go | 107 +++++++-- service/intel/customlists/lists.go | 22 +- service/intel/customlists/module.go | 81 ++++--- service/intel/filterlists/module.go | 100 +++++--- service/intel/filterlists/updater.go | 49 ++-- service/intel/geoip/database.go | 6 +- service/intel/geoip/module.go | 52 +++-- service/intel/module.go | 15 +- service/mgr/sleepyticker.go | 58 +++++ service/nameserver/module.go | 20 +- service/nameserver/nameserver.go | 5 +- service/netenv/api.go | 20 +- service/netquery/module_api.go | 12 +- service/network/api.go | 6 +- service/network/clean.go | 6 +- service/network/connection.go | 9 +- service/network/database.go | 4 +- service/network/dns.go | 6 +- service/network/module.go | 12 + service/process/api.go | 3 - service/profile/active.go | 1 - service/profile/api.go | 3 - service/profile/config-update.go | 41 ++-- service/profile/get.go | 6 +- service/profile/migrations.go | 23 +- service/profile/module.go | 12 +- service/resolver/api.go | 7 +- service/resolver/failing.go | 8 +- service/resolver/main.go | 20 +- service/resolver/metrics.go | 10 +- service/resolver/resolve.go | 7 +- service/resolver/resolver-mdns.go | 2 +- service/resolver/resolver-tcp.go | 31 +-- service/resolver/resolvers.go | 29 +-- service/sync/profile.go | 4 +- service/sync/setting_single.go | 4 +- service/sync/settings.go | 4 +- service/updates/main.go | 47 ++-- service/updates/notify.go | 2 +- spn/access/client.go | 19 +- spn/access/module.go | 21 +- spn/access/storage.go | 2 +- spn/access/token/module_test.go | 1 - spn/access/zones.go | 11 +- spn/captain/api.go | 67 +++--- spn/captain/bootstrap.go | 2 +- spn/captain/client.go | 56 +++-- spn/captain/hooks.go | 4 +- spn/captain/intel.go | 24 +- spn/captain/module.go | 61 +++-- spn/captain/navigation.go | 3 
+- spn/captain/op_gossip.go | 2 +- spn/captain/op_gossip_query.go | 7 +- spn/captain/op_publish.go | 2 +- spn/captain/piers.go | 13 +- spn/captain/public.go | 42 +--- spn/crew/sticky.go | 5 +- spn/docks/controller.go | 2 +- spn/docks/crane.go | 6 +- spn/docks/crane_establish.go | 4 +- spn/docks/crane_init.go | 12 +- spn/docks/crane_terminal.go | 2 +- spn/docks/hub_import.go | 2 +- spn/docks/op_capacity.go | 14 +- spn/docks/op_expand.go | 15 +- spn/docks/op_latency.go | 6 +- spn/docks/op_sync_state.go | 5 +- spn/docks/terminal_expansion.go | 4 +- spn/hub/hub_test.go | 1 - spn/navigator/api.go | 7 - spn/navigator/api_route.go | 1 - spn/navigator/database.go | 10 +- spn/navigator/measurements.go | 3 +- spn/navigator/update.go | 7 +- spn/patrol/http.go | 10 +- spn/patrol/module.go | 5 +- spn/ships/http_shared.go | 9 +- spn/ships/module.go | 2 + spn/ships/tcp.go | 8 +- spn/sluice/packet_listener.go | 10 +- spn/sluice/sluice.go | 9 +- spn/sluice/udp_listener.go | 10 +- spn/terminal/control_flow.go | 7 +- spn/terminal/operation_counter.go | 8 +- spn/terminal/terminal.go | 12 +- spn/terminal/testing.go | 64 ++--- spn/unit/scheduler.go | 1 - 131 files changed, 1426 insertions(+), 1099 deletions(-) create mode 100644 service/mgr/sleepyticker.go diff --git a/base/api/authentication.go b/base/api/authentication.go index cdd724702..64ae3538a 100644 --- a/base/api/authentication.go +++ b/base/api/authentication.go @@ -122,7 +122,7 @@ type AuthenticatedHandler interface { // SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted. func SetAuthenticator(fn AuthenticatorFunc) error { - if module.Online() { + if module.online { return ErrAuthenticationImmutable } @@ -351,7 +351,7 @@ func checkAPIKey(r *http.Request) *AuthToken { return token } -func updateAPIKeys(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { +func updateAPIKeys(_ context.Context) error { apiKeysLock.Lock() defer apiKeysLock.Unlock() @@ -444,7 +444,7 @@ func updateAPIKeys(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { }) } - return false, nil + return nil } func checkSessionCookie(r *http.Request) *AuthToken { @@ -507,7 +507,7 @@ func createSession(w http.ResponseWriter, r *http.Request, token *AuthToken) err return nil } -func cleanSessions(_ context.Context, _ *modules.Task) error { +func cleanSessions(_ *mgr.WorkerCtx) error { sessionsLock.Lock() defer sessionsLock.Unlock() diff --git a/base/api/database.go b/base/api/database.go index e7bdb7073..ee1e7eae1 100644 --- a/base/api/database.go +++ b/base/api/database.go @@ -2,7 +2,6 @@ package api import ( "bytes" - "context" "errors" "fmt" "net/http" @@ -21,6 +20,7 @@ import ( "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" ) const ( @@ -122,13 +122,13 @@ func startDatabaseWebsocketAPI(w http.ResponseWriter, r *http.Request) { newDBAPI.sendQueue <- data } - module.StartWorker("database api handler", newDBAPI.handler) - module.StartWorker("database api writer", newDBAPI.writer) + module.mgr.Go("database api handler", newDBAPI.handler) + module.mgr.Go("database api writer", newDBAPI.writer) log.Tracer(r.Context()).Infof("api request: init websocket %s %s", r.RemoteAddr, r.RequestURI) } -func (api *DatabaseWebsocketAPI) handler(context.Context) error { +func (api *DatabaseWebsocketAPI) handler(_ *mgr.WorkerCtx) error { defer func() { _ = api.shutdown(nil) }() @@ -143,7 +143,7 @@ func (api 
*DatabaseWebsocketAPI) handler(context.Context) error { } } -func (api *DatabaseWebsocketAPI) writer(ctx context.Context) error { +func (api *DatabaseWebsocketAPI) writer(ctx *mgr.WorkerCtx) error { defer func() { _ = api.shutdown(nil) }() diff --git a/base/api/endpoints.go b/base/api/endpoints.go index 8f8769b88..7ee14a4c0 100644 --- a/base/api/endpoints.go +++ b/base/api/endpoints.go @@ -16,7 +16,6 @@ import ( "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" ) // Endpoint describes an API Endpoint. @@ -62,10 +61,6 @@ type Endpoint struct { //nolint:maligned // access if the write method does not match. WriteMethod string `json:",omitempty"` - // BelongsTo defines which module this endpoint belongs to. - // The endpoint will not be accessible if the module is not online. - BelongsTo *modules.Module `json:"-"` - // ActionFunc is for simple actions with a return message for the user. ActionFunc ActionFunc `json:"-"` @@ -379,12 +374,6 @@ func (e *Endpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Wait for the owning module to be ready. - if !moduleIsReady(e.BelongsTo) { - http.Error(w, "The API endpoint is not ready yet or the its module is not enabled. Reload (F5) to try again.", http.StatusServiceUnavailable) - return - } - // Return OPTIONS request before starting to handle normal requests. if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) diff --git a/base/api/endpoints_debug.go b/base/api/endpoints_debug.go index dc1ba6ff5..da1397b6a 100644 --- a/base/api/endpoints_debug.go +++ b/base/api/endpoints_debug.go @@ -3,7 +3,6 @@ package api import ( "bytes" "context" - "errors" "fmt" "net/http" "os" @@ -12,7 +11,6 @@ import ( "time" "github.com/safing/portmaster/base/info" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/utils/debug" ) @@ -131,18 +129,19 @@ You can easily view this data in your browser with this command (with Go install // ping responds with pong. func ping(ar *Request) (msg string, err error) { // TODO: Remove upgrade to "ready" when all UI components have transitioned. - if modules.IsStarting() || modules.IsShuttingDown() { - return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) - } + // if modules.IsStarting() || modules.IsShuttingDown() { + // return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) + // } return "Pong.", nil } // ready checks if Portmaster has completed starting. func ready(ar *Request) (msg string, err error) { - if modules.IsStarting() || modules.IsShuttingDown() { - return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) - } + // TODO(vladimir): provide alternative for this. Instance state? 
+ // if modules.IsStarting() || modules.IsShuttingDown() { + // return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) + // } return "Portmaster is ready.", nil } diff --git a/base/api/endpoints_modules.go b/base/api/endpoints_modules.go index 22f6af3a8..7be727956 100644 --- a/base/api/endpoints_modules.go +++ b/base/api/endpoints_modules.go @@ -1,56 +1,51 @@ package api -import ( - "errors" - "fmt" - - "github.com/safing/portmaster/base/modules" -) - func registerModulesEndpoints() error { - if err := RegisterEndpoint(Endpoint{ - Path: "modules/status", - Read: PermitUser, - StructFunc: getStatusfunc, - Name: "Get Module Status", - Description: "Returns status information of all modules.", - }); err != nil { - return err - } - - if err := RegisterEndpoint(Endpoint{ - Path: "modules/{moduleName:.+}/trigger/{eventName:.+}", - Write: PermitSelf, - ActionFunc: triggerEvent, - Name: "Trigger Event", - Description: "Triggers an event of an internal module.", - }); err != nil { - return err - } + // TODO(vladimir): do we need this? + // if err := RegisterEndpoint(Endpoint{ + // Path: "modules/status", + // Read: PermitUser, + // StructFunc: getStatusfunc, + // Name: "Get Module Status", + // Description: "Returns status information of all modules.", + // }); err != nil { + // return err + // } + + // TODO(vladimir): do we need this? + // if err := RegisterEndpoint(Endpoint{ + // Path: "modules/{moduleName:.+}/trigger/{eventName:.+}", + // Write: PermitSelf, + // ActionFunc: triggerEvent, + // Name: "Trigger Event", + // Description: "Triggers an event of an internal module.", + // }); err != nil { + // return err + // } return nil } -func getStatusfunc(ar *Request) (i interface{}, err error) { - status := modules.GetStatus() - if status == nil { - return nil, errors.New("modules not yet initialized") - } - return status, nil -} - -func triggerEvent(ar *Request) (msg string, err error) { - // Get parameters. - moduleName := ar.URLVars["moduleName"] - eventName := ar.URLVars["eventName"] - if moduleName == "" || eventName == "" { - return "", errors.New("invalid parameters") - } - - // Inject event. - if err := module.InjectEvent("api event injection", moduleName, eventName, nil); err != nil { - return "", fmt.Errorf("failed to inject event: %w", err) - } - - return "event successfully injected", nil -} +// func getStatusfunc(ar *Request) (i interface{}, err error) { +// status := modules.GetStatus() +// if status == nil { +// return nil, errors.New("modules not yet initialized") +// } +// return status, nil +// } + +// func triggerEvent(ar *Request) (msg string, err error) { +// // Get parameters. +// moduleName := ar.URLVars["moduleName"] +// eventName := ar.URLVars["eventName"] +// if moduleName == "" || eventName == "" { +// return "", errors.New("invalid parameters") +// } + +// // Inject event. 
+// if err := module.InjectEvent("api event injection", moduleName, eventName, nil); err != nil { +// return "", fmt.Errorf("failed to inject event: %w", err) +// } + +// return "event successfully injected", nil +// } diff --git a/base/api/main.go b/base/api/main.go index 7693e52d1..26b79350d 100644 --- a/base/api/main.go +++ b/base/api/main.go @@ -6,6 +6,9 @@ import ( "flag" "os" "time" + + "github.com/safing/portbase/modules" + "github.com/safing/portmaster/service/mgr" ) var exportEndpoints bool @@ -53,12 +56,15 @@ func prep() error { func start() error { startServer() - _ = updateAPIKeys(module.mgr.Ctx(), nil) - module.instance.Config().EventConfigChange.AddCallback("update API keys", updateAPIKeys) + _ = updateAPIKeys(module.mgr.Ctx()) + module.instance.Config().EventConfigChange.AddCallback("update API keys", + func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + return false, updateAPIKeys(wc.Ctx()) + }) // start api auth token cleaner if authFnSet.IsSet() { - module.NewTask("clean api sessions", cleanSessions).Repeat(5 * time.Minute) + module.mgr.Repeat("clean api sessions", 5*time.Minute, cleanSessions) } return registerEndpointBridgeDB() diff --git a/base/api/main_test.go b/base/api/main_test.go index df06dc631..6f68ad241 100644 --- a/base/api/main_test.go +++ b/base/api/main_test.go @@ -1,56 +1,55 @@ package api import ( - "fmt" - "os" - "testing" +// "fmt" +// "os" +// "testing" - // API depends on the database for the database api. - _ "github.com/safing/portmaster/base/database/dbmodule" - "github.com/safing/portmaster/base/dataroot" - "github.com/safing/portmaster/base/modules" +// API depends on the database for the database api. +// _ "github.com/safing/portmaster/base/database/dbmodule" +// "github.com/safing/portmaster/base/dataroot" ) func init() { defaultListenAddress = "127.0.0.1:8817" } -func TestMain(m *testing.M) { - // enable module for testing - module.Enable() +// func TestMain(m *testing.M) { +// // enable module for testing +// module.Enable() - // tmp dir for data root (db & config) - tmpDir, err := os.MkdirTemp("", "portbase-testing-") - if err != nil { - fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err) - os.Exit(1) - } - // initialize data dir - err = dataroot.Initialize(tmpDir, 0o0755) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) - os.Exit(1) - } +// // tmp dir for data root (db & config) +// tmpDir, err := os.MkdirTemp("", "portbase-testing-") +// if err != nil { +// fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err) +// os.Exit(1) +// } +// // initialize data dir +// err = dataroot.Initialize(tmpDir, 0o0755) +// if err != nil { +// fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) +// os.Exit(1) +// } - // start modules - var exitCode int - err = modules.Start() - if err != nil { - // starting failed - fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err) - exitCode = 1 - } else { - // run tests - exitCode = m.Run() - } +// // start modules +// var exitCode int +// err = modules.Start() +// if err != nil { +// // starting failed +// fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err) +// exitCode = 1 +// } else { +// // run tests +// exitCode = m.Run() +// } - // shutdown - _ = modules.Shutdown() - if modules.GetExitStatusCode() != 0 { - exitCode = modules.GetExitStatusCode() - fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err) - } - // clean up and exit - _ = os.RemoveAll(tmpDir) - os.Exit(exitCode) -} +// // shutdown +// _ = 
modules.Shutdown() +// if modules.GetExitStatusCode() != 0 { +// exitCode = modules.GetExitStatusCode() +// fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err) +// } +// // clean up and exit +// _ = os.RemoveAll(tmpDir) +// os.Exit(exitCode) +// } diff --git a/base/api/module.go b/base/api/module.go index 740e8d4d4..af5ea3a1f 100644 --- a/base/api/module.go +++ b/base/api/module.go @@ -12,12 +12,22 @@ import ( type API struct { mgr *mgr.Manager instance instance + + online bool } // Start starts the module. func (api *API) Start(m *mgr.Manager) error { api.mgr = m - return start() + if err := prep(); err != nil { + return err + } + if err := start(); err != nil { + return err + } + + module.online = true + return nil } // Stop stops the module. @@ -36,10 +46,6 @@ func New(instance instance) (*API, error) { return nil, errors.New("only one instance allowed") } - if err := prep(); err != nil { - return nil, err - } - module = &API{ instance: instance, } diff --git a/base/api/modules.go b/base/api/modules.go index c5c366db8..aa9190e86 100644 --- a/base/api/modules.go +++ b/base/api/modules.go @@ -2,14 +2,12 @@ package api import ( "time" - - "github.com/safing/portmaster/base/modules" ) // ModuleHandler specifies the interface for API endpoints that are bound to a module. -type ModuleHandler interface { - BelongsTo() *modules.Module -} +// type ModuleHandler interface { +// BelongsTo() *modules.Module +// } const ( moduleCheckMaxWait = 10 * time.Second @@ -19,31 +17,31 @@ const ( // moduleIsReady checks if the given module is online and http requests can be // sent its way. If the module is not online already, it will wait for a short // duration for it to come online. -func moduleIsReady(m *modules.Module) (ok bool) { - // Check if we are given a module. - if m == nil { - // If no module is given, we assume that the handler has not been linked to - // a module, and we can safely continue with the request. - return true - } - - // Check if the module is online. - if m.Online() { - return true - } - - // Check if the module will come online. - if m.OnlineSoon() { - var i time.Duration - for i = 0; i < moduleCheckMaxWait; i += moduleCheckTickDuration { - // Wait a little. - time.Sleep(moduleCheckTickDuration) - // Check if module is now online. - if m.Online() { - return true - } - } - } - - return false -} +// func moduleIsReady(m *modules.Module) (ok bool) { +// // Check if we are given a module. +// if m == nil { +// // If no module is given, we assume that the handler has not been linked to +// // a module, and we can safely continue with the request. +// return true +// } + +// // Check if the module is online. +// if m.Online() { +// return true +// } + +// // Check if the module will come online. +// if m.OnlineSoon() { +// var i time.Duration +// for i = 0; i < moduleCheckMaxWait; i += moduleCheckTickDuration { +// // Wait a little. +// time.Sleep(moduleCheckTickDuration) +// // Check if module is now online. +// if m.Online() { +// return true +// } +// } +// } + +// return false +// } diff --git a/base/api/router.go b/base/api/router.go index d8c3c3d20..4ec77758d 100644 --- a/base/api/router.go +++ b/base/api/router.go @@ -16,6 +16,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/utils" + "github.com/safing/portmaster/service/mgr" ) // EnableServer defines if the HTTP server should be started. @@ -65,7 +66,7 @@ func startServer() { } // Start server manager. 
- module.StartServiceWorker("http server manager", 0, serverManager) + module.mgr.Go("http server manager", serverManager) } func stopServer() error { @@ -82,13 +83,13 @@ func stopServer() error { } // Serve starts serving the API endpoint. -func serverManager(_ context.Context) error { +func serverManager(_ *mgr.WorkerCtx) error { // start serving log.Infof("api: starting to listen on %s", server.Addr) backoffDuration := 10 * time.Second for { // always returns an error - err := module.RunWorker("http endpoint", func(ctx context.Context) error { + err := module.mgr.Do("http endpoint", func(ctx *mgr.WorkerCtx) error { return server.ListenAndServe() }) // return on shutdown error @@ -106,7 +107,7 @@ type mainHandler struct { } func (mh *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - _ = module.RunWorker("http request", func(_ context.Context) error { + _ = module.mgr.Do("http request", func(_ *mgr.WorkerCtx) error { return mh.handle(w, r) }) } @@ -269,12 +270,13 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error { } // Wait for the owning module to be ready. - if moduleHandler, ok := handler.(ModuleHandler); ok { - if !moduleIsReady(moduleHandler.BelongsTo()) { - http.Error(lrw, "The API endpoint is not ready yet. Reload (F5) to try again.", http.StatusServiceUnavailable) - return nil - } - } + // TODO(vladimir): no need to check for status anymore right? + // if moduleHandler, ok := handler.(ModuleHandler); ok { + // if !moduleIsReady(moduleHandler.BelongsTo()) { + // http.Error(lrw, "The API endpoint is not ready yet. Reload (F5) to try again.", http.StatusServiceUnavailable) + // return nil + // } + // } // Check if we have a handler. if handler == nil { @@ -286,8 +288,9 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error { defer func() { if panicValue := recover(); panicValue != nil { // Report failure via module system. - me := module.NewPanicError("api request", "custom", panicValue) - me.Report() + // TODO(vladimir): do we need panic report here + // me := module.NewPanicError("api request", "custom", panicValue) + // me.Report() // Respond with a server error. 
if devMode() { http.Error( diff --git a/base/database/dbmodule/maintenance.go b/base/database/dbmodule/maintenance.go index 22fddac33..64704460a 100644 --- a/base/database/dbmodule/maintenance.go +++ b/base/database/dbmodule/maintenance.go @@ -9,22 +9,22 @@ import ( ) func startMaintenanceTasks() { - module.mgr.Go("basic maintenance", maintainBasic).Repeat(10 * time.Minute).MaxDelay(10 * time.Minute) - module.mgr.Go("thorough maintenance", maintainThorough).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) - module.mgr.Go("record maintenance", maintainRecords).Repeat(1 * time.Hour).MaxDelay(1 * time.Hour) + module.mgr.Repeat("basic maintenance", 10*time.Minute, maintainBasic) + module.mgr.Repeat("thorough maintenance", 1*time.Hour, maintainThorough) + module.mgr.Repeat("record maintenance", 1*time.Hour, maintainRecords) } -func maintainBasic(ctx mgr.WorkerCtx) error { +func maintainBasic(ctx *mgr.WorkerCtx) error { log.Infof("database: running Maintain") return database.Maintain(ctx.Ctx()) } -func maintainThorough(ctx mgr.WorkerCtx) error { +func maintainThorough(ctx *mgr.WorkerCtx) error { log.Infof("database: running MaintainThorough") return database.MaintainThorough(ctx.Ctx()) } -func maintainRecords(ctx mgr.WorkerCtx) error { +func maintainRecords(ctx *mgr.WorkerCtx) error { log.Infof("database: running MaintainRecordStates") return database.MaintainRecordStates(ctx.Ctx()) } diff --git a/base/metrics/api.go b/base/metrics/api.go index 10a7d0d85..ccc3bfe90 100644 --- a/base/metrics/api.go +++ b/base/metrics/api.go @@ -22,7 +22,6 @@ func registerAPI() error { Description: "List all registered metrics with their metadata.", Path: "metrics/list", Read: api.Dynamic, - BelongsTo: module, StructFunc: func(ar *api.Request) (any, error) { return ExportMetrics(ar.AuthToken.Read), nil }, @@ -40,7 +39,6 @@ func registerAPI() error { Field: "internal-only", Description: "Specify to only return metrics with an alternative internal ID.", }}, - BelongsTo: module, StructFunc: func(ar *api.Request) (any, error) { return ExportValues( ar.AuthToken.Read, @@ -142,14 +140,14 @@ func writeMetricsTo(ctx context.Context, url string) error { func metricsWriter(ctx *mgr.WorkerCtx) error { pushURL := pushOption() - ticker := module.NewSleepyTicker(1*time.Minute, 0) - defer ticker.Stop() + module.metricTicker = mgr.NewSleepyTicker(1*time.Minute, 0) + defer module.metricTicker.Stop() for { select { case <-ctx.Done(): return nil - case <-ticker.Wait(): + case <-module.metricTicker.Wait(): err := writeMetricsTo(ctx.Ctx(), pushURL) if err != nil { return err diff --git a/base/metrics/module.go b/base/metrics/module.go index 1d4eb2ce8..a950c56e1 100644 --- a/base/metrics/module.go +++ b/base/metrics/module.go @@ -7,17 +7,19 @@ import ( "sync" "sync/atomic" - "github.com/safing/portbase/modules" "github.com/safing/portmaster/service/mgr" ) type Metrics struct { mgr *mgr.Manager instance instance + + metricTicker *mgr.SleepyTicker } func (met *Metrics) Start(m *mgr.Manager) error { met.mgr = m + if err := prepConfig(); err != nil { return err } @@ -28,6 +30,12 @@ func (met *Metrics) Stop(m *mgr.Manager) error { return stop() } +func (met *Metrics) SetSleep(enabled bool) { + if met.metricTicker != nil { + met.metricTicker.SetSleep(enabled) + } +} + var ( module *Metrics shimLoaded atomic.Bool @@ -128,9 +136,10 @@ func register(m Metric) error { // Set flag that first metric is now registered. 
firstMetricRegistered = true - if module.Status() < modules.StatusStarting { - return fmt.Errorf("registering metric %q too early", m.ID()) - } + // TODO(vladimir): With the new modules system there is no way this can fail. I may be wrong. + // if module.Status() < modules.StatusStarting { + // return fmt.Errorf("registering metric %q too early", m.ID()) + // } return nil } diff --git a/base/notifications/module-mirror.go b/base/notifications/module-mirror.go index 43df67d90..e1614fa0d 100644 --- a/base/notifications/module-mirror.go +++ b/base/notifications/module-mirror.go @@ -1,116 +1,116 @@ package notifications import ( - "github.com/safing/portbase/modules" - "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/service/mgr" +// "github.com/safing/portbase/modules" +// "github.com/safing/portmaster/base/log" +// "github.com/safing/portmaster/service/mgr" ) // AttachToModule attaches the notification to a module and changes to the // notification will be reflected on the module failure status. -func (n *Notification) AttachToModule(m mgr.Module) { - if m == nil { - log.Warningf("notifications: invalid usage: cannot attach %s to nil module", n.EventID) - return - } - - n.lock.Lock() - defer n.lock.Unlock() - - if n.State != Active { - log.Warningf("notifications: cannot attach module to inactive notification %s", n.EventID) - return - } - if n.belongsTo != nil { - log.Warningf("notifications: cannot override attached module for notification %s", n.EventID) - return - } - - // Attach module. - n.belongsTo = m - - // Set module failure status. - switch n.Type { //nolint:exhaustive - case Info: - m.Hint(n.EventID, n.Title, n.Message) - case Warning: - m.Warning(n.EventID, n.Title, n.Message) - case Error: - m.Error(n.EventID, n.Title, n.Message) - default: - log.Warningf("notifications: incompatible type for attaching to module in notification %s", n.EventID) - m.Error(n.EventID, n.Title, n.Message+" [incompatible notification type]") - } -} - -// resolveModuleFailure removes the notification from the module failure status. -func (n *Notification) resolveModuleFailure() { - if n.belongsTo != nil { - // Resolve failure in attached module. - n.belongsTo.Resolve(n.EventID) - - // Reset attachment in order to mitigate duplicate failure resolving. - // Re-attachment is prevented by the state check when attaching. - n.belongsTo = nil - } -} - -func init() { - modules.SetFailureUpdateNotifyFunc(mirrorModuleStatus) -} - -func mirrorModuleStatus(moduleFailure uint8, id, title, msg string) { - // Ignore "resolve all" requests. - if id == "" { - return - } - - // Get notification from storage. - n, ok := getNotification(id) - if ok { - // The notification already exists. - - // Check if we should delete it. - if moduleFailure == modules.FailureNone && !n.Meta().IsDeleted() { - - // Remove belongsTo, as the deletion was already triggered by the module itself. - n.Lock() - n.belongsTo = nil - n.Unlock() - - n.Delete() - } - - return - } - - // A notification for the given ID does not yet exists, create it. 
- n = &Notification{ - EventID: id, - Title: title, - Message: msg, - AvailableActions: []*Action{ - { - Text: "Get Help", - Type: ActionTypeOpenURL, - Payload: "https://safing.io/support/", - }, - }, - } - - switch moduleFailure { - case modules.FailureNone: - return - case modules.FailureHint: - n.Type = Info - n.AvailableActions = nil - case modules.FailureWarning: - n.Type = Warning - n.ShowOnSystem = true - case modules.FailureError: - n.Type = Error - n.ShowOnSystem = true - } - - Notify(n) -} +// func (n *Notification) AttachToState(state *mgr.StateMgr) { +// if state == nil { +// log.Warningf("notifications: invalid usage: cannot attach %s to nil module", n.EventID) +// return +// } + +// n.lock.Lock() +// defer n.lock.Unlock() + +// if n.State != Active { +// log.Warningf("notifications: cannot attach module to inactive notification %s", n.EventID) +// return +// } +// if n.belongsTo != nil { +// log.Warningf("notifications: cannot override attached module for notification %s", n.EventID) +// return +// } + +// // Attach module. +// n.belongsTo = state + +// // Set module failure status. +// switch n.Type { //nolint:exhaustive +// case Info: +// m.Hint(n.EventID, n.Title, n.Message) +// case Warning: +// m.Warning(n.EventID, n.Title, n.Message) +// case Error: +// m.Error(n.EventID, n.Title, n.Message) +// default: +// log.Warningf("notifications: incompatible type for attaching to module in notification %s", n.EventID) +// m.Error(n.EventID, n.Title, n.Message+" [incompatible notification type]") +// } +// } + +// // resolveModuleFailure removes the notification from the module failure status. +// func (n *Notification) resolveModuleFailure() { +// if n.belongsTo != nil { +// // Resolve failure in attached module. +// n.belongsTo.Resolve(n.EventID) + +// // Reset attachment in order to mitigate duplicate failure resolving. +// // Re-attachment is prevented by the state check when attaching. +// n.belongsTo = nil +// } +// } + +// func init() { +// modules.SetFailureUpdateNotifyFunc(mirrorModuleStatus) +// } + +// func mirrorModuleStatus(moduleFailure uint8, id, title, msg string) { +// // Ignore "resolve all" requests. +// if id == "" { +// return +// } + +// // Get notification from storage. +// n, ok := getNotification(id) +// if ok { +// // The notification already exists. + +// // Check if we should delete it. +// if moduleFailure == modules.FailureNone && !n.Meta().IsDeleted() { + +// // Remove belongsTo, as the deletion was already triggered by the module itself. +// n.Lock() +// n.belongsTo = nil +// n.Unlock() + +// n.Delete() +// } + +// return +// } + +// // A notification for the given ID does not yet exists, create it. 
+// n = &Notification{ +// EventID: id, +// Title: title, +// Message: msg, +// AvailableActions: []*Action{ +// { +// Text: "Get Help", +// Type: ActionTypeOpenURL, +// Payload: "https://safing.io/support/", +// }, +// }, +// } + +// switch moduleFailure { +// case modules.FailureNone: +// return +// case modules.FailureHint: +// n.Type = Info +// n.AvailableActions = nil +// case modules.FailureWarning: +// n.Type = Warning +// n.ShowOnSystem = true +// case modules.FailureError: +// n.Type = Error +// n.ShowOnSystem = true +// } + +// Notify(n) +// } diff --git a/base/notifications/notification.go b/base/notifications/notification.go index c088f78d7..68526ab17 100644 --- a/base/notifications/notification.go +++ b/base/notifications/notification.go @@ -8,8 +8,8 @@ import ( "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/utils" + "github.com/safing/portmaster/service/mgr" ) // Type describes the type of a notification. @@ -99,9 +99,9 @@ type Notification struct { //nolint:maligned // based on the user selection. SelectedActionID string - // belongsTo holds the module this notification belongs to. The notification - // lifecycle will be mirrored to the module's failure status. - belongsTo *modules.Module + // belongsTo holds the state this notification belongs to. The notification + // lifecycle will be mirrored to the specified failure status. + // belongsTo *mgr.StateMgr lock sync.Mutex actionFunction NotificationActionFn // call function to process action @@ -442,7 +442,7 @@ func (n *Notification) delete(pushUpdate bool) { dbController.PushUpdate(n) } - n.resolveModuleFailure() + // n.resolveModuleFailure() } // Expired notifies the caller when the notification has expired. 
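The hunk below repeats a pattern applied throughout this patch: workers formerly started via module.StartWorker with a plain context.Context are now started via module.mgr.Go and receive a *mgr.WorkerCtx, unwrapping the standard context with Ctx() wherever the callee keeps its old signature. A minimal sketch of that adaptation, using hypothetical doWork and runWork helpers; only mgr.Manager.Go and WorkerCtx.Ctx are taken from the patch:

package workers // illustrative sketch, not part of the patch

import (
	"context"

	"github.com/safing/portmaster/service/mgr"
)

// doWork stands in for any callback that still expects a plain
// context.Context, such as the notification action functions adapted below.
func doWork(ctx context.Context) error {
	_ = ctx // a real callback would do work here and honor cancellation
	return nil
}

// runWork starts the worker through the manager; the worker receives a
// *mgr.WorkerCtx and hands the embedded standard context to the legacy callee.
func runWork(m *mgr.Manager) {
	m.Go("example worker", func(wc *mgr.WorkerCtx) error {
		return doWork(wc.Ctx())
	})
}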
@@ -468,8 +468,8 @@ func (n *Notification) selectAndExecuteAction(id string) { executed := false if n.actionFunction != nil { - module.StartWorker("notification action execution", func(ctx context.Context) error { - return n.actionFunction(ctx, n) + module.mgr.Go("notification action execution", func(ctx *mgr.WorkerCtx) error { + return n.actionFunction(ctx.Ctx(), n) }) executed = true } @@ -495,7 +495,7 @@ func (n *Notification) selectAndExecuteAction(id string) { if executed { n.State = Executed - n.resolveModuleFailure() + // n.resolveModuleFailure() } } diff --git a/base/rng/test/main.go b/base/rng/test/main.go index bc8883eeb..68ad0cbe2 100644 --- a/base/rng/test/main.go +++ b/base/rng/test/main.go @@ -15,9 +15,7 @@ import ( "time" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/rng" - "github.com/safing/portmaster/base/run" ) var ( diff --git a/base/runtime/modules_integration.go b/base/runtime/modules_integration.go index c85ac330f..de85d0b9d 100644 --- a/base/runtime/modules_integration.go +++ b/base/runtime/modules_integration.go @@ -6,8 +6,6 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" ) var modulesIntegrationUpdatePusher func(...record.Record) @@ -18,9 +16,9 @@ func startModulesIntegration() (err error) { return err } - if !modules.SetEventSubscriptionFunc(pushModuleEvent) { - log.Warningf("runtime: failed to register the modules event subscription function") - } + // if !modules.SetEventSubscriptionFunc(pushModuleEvent) { + // log.Warningf("runtime: failed to register the modules event subscription function") + // } return nil } diff --git a/base/template/module_test.go b/base/template/module_test.go index 2a41f02b9..824c195ef 100644 --- a/base/template/module_test.go +++ b/base/template/module_test.go @@ -7,7 +7,6 @@ import ( _ "github.com/safing/portmaster/base/database/dbmodule" "github.com/safing/portmaster/base/dataroot" - "github.com/safing/portmaster/base/modules" ) func TestMain(m *testing.M) { diff --git a/base/utils/debug/debug.go b/base/utils/debug/debug.go index 06ac7b937..0446f9d90 100644 --- a/base/utils/debug/debug.go +++ b/base/utils/debug/debug.go @@ -4,12 +4,10 @@ import ( "bytes" "fmt" "runtime/pprof" - "strings" "time" "github.com/safing/portmaster/base/info" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" ) // Info gathers debugging information and stores everything in a buffer in @@ -114,17 +112,17 @@ func (di *Info) AddGoroutineStack() { // AddLastReportedModuleError adds the last reported module error, if one exists. func (di *Info) AddLastReportedModuleError() { - me := modules.GetLastReportedError() - if me == nil { - di.AddSection("No Module Error", NoFlags) - return - } - - di.AddSection( - fmt.Sprintf("%s Module Error", strings.Title(me.ModuleName)), //nolint:staticcheck - UseCodeSection, - me.Format(), - ) + // me := modules.GetLastReportedError() + // if me == nil { + // di.AddSection("No Module Error", NoFlags) + // return + // } + + // di.AddSection( + // fmt.Sprintf("%s Module Error", strings.Title(me.ModuleName)), //nolint:staticcheck + // UseCodeSection, + // me.Format(), + // ) } // AddLastUnexpectedLogs adds the last 10 unexpected log lines, if any. 
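With modules.GetLastReportedError gone, AddLastReportedModuleError above keeps only a commented-out body, so the generated debug report now silently drops its module error section. Until an mgr-based replacement exists, one low-risk interim option, reusing only the AddSection call and the NoFlags constant already present in the commented code, could be:

func (di *Info) AddLastReportedModuleError() {
	// The old modules error registry no longer exists; state that explicitly
	// instead of omitting the section from the debug report.
	di.AddSection("No Module Error", NoFlags)
}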
diff --git a/cmds/hub/main.go b/cmds/hub/main.go index 4b67299f7..dada02d05 100644 --- a/cmds/hub/main.go +++ b/cmds/hub/main.go @@ -8,8 +8,6 @@ import ( "github.com/safing/portmaster/base/info" "github.com/safing/portmaster/base/metrics" - "github.com/safing/portmaster/base/modules" - "github.com/safing/portmaster/base/run" _ "github.com/safing/portmaster/service/core/base" _ "github.com/safing/portmaster/service/ui" "github.com/safing/portmaster/service/updates" diff --git a/cmds/notifier/main.go b/cmds/notifier/main.go index ef4f0e603..164aeb003 100644 --- a/cmds/notifier/main.go +++ b/cmds/notifier/main.go @@ -21,7 +21,6 @@ import ( "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/info" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/updates/helper" diff --git a/cmds/observation-hub/apprise.go b/cmds/observation-hub/apprise.go index bb16685cd..501c06131 100644 --- a/cmds/observation-hub/apprise.go +++ b/cmds/observation-hub/apprise.go @@ -14,7 +14,6 @@ import ( "github.com/safing/portmaster/base/apprise" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/intel/geoip" ) diff --git a/cmds/observation-hub/main.go b/cmds/observation-hub/main.go index dfa7c582e..cf3c11144 100644 --- a/cmds/observation-hub/main.go +++ b/cmds/observation-hub/main.go @@ -8,8 +8,6 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/info" "github.com/safing/portmaster/base/metrics" - "github.com/safing/portmaster/base/modules" - "github.com/safing/portmaster/base/run" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/service/updates/helper" "github.com/safing/portmaster/spn/captain" diff --git a/cmds/observation-hub/observe.go b/cmds/observation-hub/observe.go index ca4e64038..cec4c687c 100644 --- a/cmds/observation-hub/observe.go +++ b/cmds/observation-hub/observe.go @@ -15,7 +15,6 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/captain" "github.com/safing/portmaster/spn/navigator" ) diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index 764e8ef33..11a865c78 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -3,18 +3,17 @@ package main import ( "fmt" - "os" "runtime" "github.com/safing/portmaster/base/info" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/metrics" - "github.com/safing/portmaster/base/run" + "github.com/safing/portmaster/service" + "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/spn/conf" // Include packages here. 
- _ "github.com/safing/portmaster/base/modules/subsystems" _ "github.com/safing/portmaster/service/core" _ "github.com/safing/portmaster/service/firewall" _ "github.com/safing/portmaster/service/nameserver" @@ -38,6 +37,26 @@ func main() { // enable SPN client mode conf.EnableClient(true) - // start - os.Exit(run.Run()) + // Create + instance, err := service.New("2.0.0", &service.ServiceConfig{ + ShutdownFunc: func(exitCode int) { + fmt.Printf("ExitCode: %d\n", exitCode) + }, + }) + if err != nil { + fmt.Printf("error creating an instance: %s\n", err) + return + } + // Prep + err = base.GlobalPrep() + if err != nil { + fmt.Printf("global prep failed: %s\n", err) + return + } + // Start + err = instance.Group.Start() + if err != nil { + fmt.Printf("instance start failed: %s\n", err) + return + } } diff --git a/go.mod b/go.mod index aeb95a003..a3ea7e29c 100644 --- a/go.mod +++ b/go.mod @@ -7,25 +7,36 @@ replace github.com/tc-hib/winres => github.com/dhaavi/winres v0.2.2 require ( fyne.io/systray v1.10.0 + github.com/VictoriaMetrics/metrics v1.33.1 github.com/Xuanwo/go-locale v1.1.0 + github.com/aead/serpent v0.0.0-20160714141033-fba169763ea6 github.com/agext/levenshtein v1.2.3 + github.com/armon/go-radix v1.0.0 github.com/awalterschulze/gographviz v2.0.3+incompatible + github.com/bluele/gcache v0.0.2 github.com/brianvoe/gofakeit v3.18.0+incompatible github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 + github.com/davecgh/go-spew v1.1.1 + github.com/dgraph-io/badger v1.6.2 github.com/dhaavi/go-notify v0.0.0-20190209221809-c404b1f22435 github.com/florianl/go-conntrack v0.4.0 github.com/florianl/go-nfqueue v1.3.2 github.com/fogleman/gg v1.3.0 + github.com/fxamacker/cbor/v2 v2.6.0 github.com/ghodss/yaml v1.0.0 github.com/godbus/dbus/v5 v5.1.0 + github.com/gofrs/uuid v4.4.0+incompatible github.com/google/gopacket v1.1.19 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 + github.com/gorilla/mux v1.8.1 + github.com/gorilla/websocket v1.5.1 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-version v1.7.0 github.com/jackc/puddle/v2 v2.2.1 github.com/mat/besticon v3.12.0+incompatible github.com/miekg/dns v1.1.59 + github.com/mitchellh/copystructure v1.2.0 github.com/mitchellh/go-server-timing v1.0.1 github.com/mr-tron/base58 v1.2.0 github.com/oschwald/maxminddb-golang v1.12.0 @@ -34,6 +45,7 @@ require ( github.com/safing/jess v0.3.3 github.com/safing/portbase v0.19.5 github.com/safing/portmaster-android/go v0.0.0-20230830120134-3226ceac3bec + github.com/seehuhn/fortuna v1.0.1 github.com/shirou/gopsutil v3.21.11+incompatible github.com/spf13/cobra v1.8.0 github.com/spkg/zipfs v0.7.1 @@ -41,8 +53,12 @@ require ( github.com/tannerryan/ring v1.1.2 github.com/tc-hib/winres v0.3.1 github.com/tevino/abool v1.2.0 + github.com/tidwall/gjson v1.17.1 + github.com/tidwall/sjson v1.2.5 github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 github.com/vincent-petithory/dataurl v1.0.0 + github.com/vmihailenco/msgpack/v5 v5.4.1 + go.etcd.io/bbolt v1.3.10 golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d golang.org/x/image v0.16.0 golang.org/x/net v0.25.0 @@ -53,28 +69,24 @@ require ( ) require ( - github.com/VictoriaMetrics/metrics v1.33.1 // indirect + github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect github.com/aead/ecdh v0.2.0 // indirect - github.com/aead/serpent v0.0.0-20160714141033-fba169763ea6 // indirect github.com/alessio/shellescape v1.4.2 // indirect - github.com/armon/go-radix v1.0.0 // indirect - github.com/bluele/gcache 
v0.0.2 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/danieljoos/wincred v1.2.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fxamacker/cbor v1.5.1 // indirect - github.com/fxamacker/cbor/v2 v2.6.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/godbus/dbus v4.1.0+incompatible // indirect - github.com/gofrs/uuid v4.4.0+incompatible // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f // indirect + github.com/golang/glog v1.2.0 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/gorilla/mux v1.8.1 // indirect - github.com/gorilla/websocket v1.5.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect @@ -82,36 +94,32 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect - github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/satori/go.uuid v1.2.0 // indirect - github.com/seehuhn/fortuna v1.0.1 // indirect github.com/seehuhn/sha256d v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/tidwall/gjson v1.17.1 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - github.com/tidwall/sjson v1.2.5 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/valyala/fastrand v1.1.0 // indirect github.com/valyala/histogram v1.2.0 // indirect - github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/zalando/go-keyring v0.2.4 // indirect github.com/zeebo/blake3 v0.2.3 // indirect - go.etcd.io/bbolt v1.3.10 // indirect golang.org/x/crypto v0.23.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/text v0.15.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.0 // indirect + google.golang.org/protobuf v1.32.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gvisor.dev/gvisor v0.0.0-20240524212851-a244eff8ad49 // indirect modernc.org/libc v1.50.9 // indirect diff --git a/go.sum b/go.sum index 466b106ae..a7eb61a2b 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,7 @@ github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIo github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.4.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/OneOfOne/xxhash v1.2.2/go.mod 
h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/VictoriaMetrics/metrics v1.33.1 h1:CNV3tfm2Kpv7Y9W3ohmvqgFWPR55tV2c7M2U6OIo+UM= github.com/VictoriaMetrics/metrics v1.33.1/go.mod h1:r7hveu6xMdUACXvB8TYdAj8WEsKzWB0EkpJN+RDtOf8= github.com/Xuanwo/go-locale v1.1.0 h1:51gUxhxl66oXAjI9uPGb2O0qwPECpriKQb2hl35mQkg= @@ -17,6 +18,7 @@ github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7l github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4uEoM0= github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= +github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI= github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/awalterschulze/gographviz v2.0.3+incompatible h1:9sVEXJBJLwGX7EQVhLm2elIKCm7P2YHFC8v6096G09E= @@ -26,14 +28,20 @@ github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0 github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/brianvoe/gofakeit v3.18.0+incompatible h1:wDOmHc9DLG4nRjUVVaxA+CEglKOW72Y5+4WNxUIkjM8= github.com/brianvoe/gofakeit v3.18.0+incompatible/go.mod h1:kfwdRA90vvNhPutZWfH7WPaDzUjz+CZFqG+rPkOjGOc= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cilium/ebpf v0.5.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs= github.com/cilium/ebpf v0.7.0/go.mod h1:/oI2+1shJiTGAMgl6/RgJr36Eo1jzrRcAWbcXO2usCA= github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= +github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/danieljoos/wincred v1.2.1 h1:dl9cBrupW8+r5250DYkYxocLeZ1Y4vB1kxgtjxw8GQs= github.com/danieljoos/wincred v1.2.1/go.mod h1:uGaFL9fDn3OLTvzCGulzE+SzjEe5NGlh5FdCcyfPwps= @@ -42,12 +50,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger v1.6.2 h1:mNw0qs90GVgGGWylh0umH5iag1j6n/PeJtNvL6KY/x8= github.com/dgraph-io/badger v1.6.2/go.mod h1:JW2yswe3V058sS0kZ2h/AXeDSqFjxnZcRrVH//y2UQE= +github.com/dgraph-io/ristretto v0.0.2/go.mod h1:KPxhHT9ZxKefz+PCeOGsrHpl1qZ7i70dGTu2u+Ahh6E= github.com/dgraph-io/ristretto v0.1.1 
h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dhaavi/go-notify v0.0.0-20190209221809-c404b1f22435 h1:AnwbdEI8eV3GzLM3SlrJlYmYa6OB5X8RwY4A8QJOCP0= github.com/dhaavi/go-notify v0.0.0-20190209221809-c404b1f22435/go.mod h1:EMJ8XWTopp8OLRBMUm9vHE8Wn48CNpU21HM817OKNrc= github.com/dhaavi/winres v0.2.2 h1:SUago7FwhgLSMyDdeuV6enBZ+ZQSl0KwcnbWzvlfBls= github.com/dhaavi/winres v0.2.2/go.mod h1:1NTs+/DtKP1BplIL1+XQSoq4X1PUfLczexS7gf3x9T4= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/felixge/httpsnoop v1.0.0/go.mod h1:3+D9sFq0ahK/JeJPhCBUV1xlf4/eIYrUQaxulT0VzX8= @@ -61,6 +73,7 @@ github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/fsnotify/fsnotify v1.4.3-0.20170329110642-4da3e2cfbabc/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fxamacker/cbor v1.5.1 h1:XjQWBgdmQyqimslUh5r4tUGmoqzHmBFQOImkWGi2awg= github.com/fxamacker/cbor v1.5.1/go.mod h1:3aPGItF174ni7dDzd6JZ206H8cmr4GDNBGpPa971zsU= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= @@ -86,10 +99,12 @@ github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGw github.com/golang/gddo v0.0.0-20180823221919-9d8ff1c67be5/go.mod h1:xEhNfoBDX1hzLm2Nf80qUvZ2sVwoMZ8d6IE2SrsQfh4= github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f h1:16RtHeWGkJMc80Etb8RPCcKevXGldr57+LOyZt8zOlg= github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f/go.mod h1:ijRvpgDJDI262hYq/IQVYgf8hd8IHUs93Ol0kvMBAx4= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/lint v0.0.0-20170918230701-e5d664eb928e/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -129,7 +144,9 @@ github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9 github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/hcl v0.0.0-20170914154624-68e816d1c783/go.mod h1:oZtUIOe8dh44I2q6ScRibXws4Ajl+d+nod3AaR9vL5w= 
+github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/log15 v0.0.0-20170622235902-74a0988b5f80/go.mod h1:cOaXtrgN4ScfRrD9Bre7U1thNq5RtJ8ZoP4iXVGRj6o= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= @@ -161,6 +178,7 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/magiconair/properties v1.7.4-0.20170902060319-8d7837e64d3c/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mat/besticon v3.12.0+incompatible h1:1KTD6wisfjfnX+fk9Kx/6VEZL+MAW1LhCkL9Q47H9Bg= github.com/mat/besticon v3.12.0+incompatible/go.mod h1:mA1auQYHt6CW5e7L9HJLmqVQC8SzNk2gVwouO0AbiEU= github.com/mattn/go-colorable v0.0.10-0.20170816031813-ad5389df28cd/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= @@ -194,9 +212,11 @@ github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-server-timing v1.0.1 h1:f00/aIe8T3MrnLhQHu3tSWvnwc5GV/p5eutuu3hF/tE= github.com/mitchellh/go-server-timing v1.0.1/go.mod h1:Mo6GKi9FSLwWFAMn3bqVPWe20y5ri5QGQuO9D9MCOxk= github.com/mitchellh/mapstructure v0.0.0-20170523030023-d0303fe80992/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= @@ -208,6 +228,8 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs= github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= github.com/pelletier/go-toml v1.0.1-0.20170904195809-1d6b12b7cb29/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -220,6 +242,7 @@ github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDN github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= 
github.com/rot256/pblind v0.0.0-20231024115251-cd3f239f28c1 h1:vfAp3Jbca7Vt8axzmkS5M/RtFJmj0CKmrtWAlHtesaA= github.com/rot256/pblind v0.0.0-20231024115251-cd3f239f28c1/go.mod h1:2x8fbm9T+uTl919COhEVHKGkve1DnkrEnDbtGptZuW8= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/safing/jess v0.3.3 h1:0U0bWdO0sFCgox+nMOqISFrnJpVmi+VFOW1xdX6q3qw= github.com/safing/jess v0.3.3/go.mod h1:t63qHB+4xd1HIv9MKN/qI2rc7ytvx7d6l4hbX7zxer0= @@ -239,20 +262,29 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykE github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.7 h1:I6tZjLXD2Q1kjvNbIzB1wvQBsXmKXiVrhpRE8ZjP5jY= github.com/smartystreets/goconvey v1.6.7/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v0.0.0-20170901052352-ee1bd8ee15a1/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= +github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.1.0/go.mod h1:r2rcYCSwa1IExKTDiTfzaxqT2FNHs8hODu4LnUfgKEg= +github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/jwalterweatherman v0.0.0-20170901151539-12bd96e66386/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= +github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.1-0.20170901120850-7aff26db30c1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.0.0/go.mod h1:A8kyI5cUJhb8N+3pkfONlcEcZbueH6nhAm0Fq7SrnBM= +github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/spkg/zipfs v0.7.1 h1:+2X5lvNHTybnDMQZAIHgedRXZK1WXdc+94R/P5v2XWE= github.com/spkg/zipfs v0.7.1/go.mod h1:48LW+/Rh1G7aAav1ew1PdlYn52T+LM+ARmSHfDNJvg8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -277,6 +309,7 @@ github.com/tklauser/go-sysconf v0.3.14 h1:g5vzr9iPFFz24v2KZXs/pvpvh8/V9Fw6vQK5ZZ github.com/tklauser/go-sysconf v0.3.14/go.mod h1:1ym4lWMLUOhuBOPGtRcJm7tEGX4SCYNEEEtghGG/8uY= 
github.com/tklauser/numcpus v0.8.0 h1:Mx4Wwe/FjZLeQsK/6kt2EOepwwSl7SmJrK5bV/dXYgY= github.com/tklauser/numcpus v0.8.0/go.mod h1:ZJZlAY+dmR4eut8epnzf0u/VwodKmryxR8txiloSqBE= +github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 h1:UFHFmFfixpmfRBcxuu+LA9l8MdURWVdVNUHxO5n1d2w= github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26/go.mod h1:IGhd0qMDsUa9acVjsbsT7bu3ktadtGOHI79+idTew/M= github.com/valyala/fastrand v1.1.0 h1:f+5HkLW4rsgzdNoleUOB69hyT9IlD2ZQh9GyDMfb5G8= @@ -292,6 +325,7 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= @@ -306,6 +340,7 @@ github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= go.etcd.io/bbolt v1.3.10 h1:+BqfJTcCzTItrop8mq/lbzL8wSGtj94UO/3U31shqG0= go.etcd.io/bbolt v1.3.10/go.mod h1:bK3UQLPJZly7IlNmV7uVHJDxfe5aK9Ll93e/74Y9oEQ= +golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -353,10 +388,12 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -384,6 +421,7 @@ golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod 
h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -417,6 +455,8 @@ google.golang.org/api v0.0.0-20170921000349-586095a6e407/go.mod h1:4mhQ8q/RsB7i+ google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20170918111702-1e559d0a00ee/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/grpc v1.2.1-0.20170921194603-d4b75ebd4f9f/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/service/broadcasts/api.go b/service/broadcasts/api.go index 4bee5195c..ea6b01cb7 100644 --- a/service/broadcasts/api.go +++ b/service/broadcasts/api.go @@ -16,7 +16,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: `broadcasts/matching-data`, Read: api.PermitAdmin, - BelongsTo: module, StructFunc: handleMatchingData, Name: "Get Broadcast Notifications Matching Data", Description: "Returns the data used by the broadcast notifications to match the instance.", @@ -28,7 +27,6 @@ func registerAPIEndpoints() error { Path: `broadcasts/reset-state`, Write: api.PermitAdmin, WriteMethod: http.MethodPost, - BelongsTo: module, ActionFunc: handleResetState, Name: "Resets the Broadcast Notification States", Description: "Delete the cache of Broadcast Notifications, making them appear again.", @@ -40,7 +38,6 @@ func registerAPIEndpoints() error { Path: `broadcasts/simulate`, Write: api.PermitAdmin, WriteMethod: http.MethodPost, - BelongsTo: module, ActionFunc: handleSimulate, Name: "Simulate Broadcast Notifications", Description: "Test broadcast notifications by sending a valid source file in the body.", diff --git a/service/broadcasts/notify.go b/service/broadcasts/notify.go index c425a4677..2ab6c9645 100644 --- a/service/broadcasts/notify.go +++ b/service/broadcasts/notify.go @@ -66,7 +66,7 @@ type BroadcastNotification struct { repeatDuration time.Duration } -func broadcastNotify(ctx mgr.WorkerCtx) error { +func broadcastNotify(ctx *mgr.WorkerCtx) error { // Get broadcast notifications file, load it from disk and parse it. broadcastsResource, err := updates.GetFile(broadcastsResourcePath) if err != nil { @@ -212,9 +212,10 @@ func handleBroadcast(bn *BroadcastNotification, matchingDataAccessor accessor.Ac n.Save() // Attach to module to raise more awareness. - if bn.AttachToModule { - n.AttachToModule(module) - } + // TODO(vladimir): is there a need for this? 
+ // if bn.AttachToModule { + // n.AttachToModule(module) + // } return nil } diff --git a/service/compat/api.go b/service/compat/api.go index 998475fac..69471d680 100644 --- a/service/compat/api.go +++ b/service/compat/api.go @@ -8,7 +8,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: "compat/self-check", Read: api.PermitUser, - BelongsTo: module, ActionFunc: selfcheckViaAPI, Name: "Run Integration Self-Check", Description: "Runs a couple integration self-checks in order to see if the system integration works.", diff --git a/service/compat/notify.go b/service/compat/notify.go index 4f93d4195..ce2a949c7 100644 --- a/service/compat/notify.go +++ b/service/compat/notify.go @@ -1,7 +1,6 @@ package compat import ( - "context" "fmt" "net" "strings" @@ -10,7 +9,6 @@ import ( "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/process" @@ -139,10 +137,11 @@ func (issue *systemIssue) notify(err error) { notifications.Notify(n) systemIssueNotification = n - n.AttachToModule(module) + // n.AttachToModule(module) // Report the raw error as module error. - module.NewErrorMessage("selfcheck", err).Report() + // FIXME(vladimir): Is there a need for this kind of error reporting? + // module.NewErrorMessage("selfcheck", err).Report() } func resetSystemIssue() { @@ -215,7 +214,7 @@ func (issue *appIssue) notify(proc *process.Process) { notifications.Notify(n) // Set warning on profile. - module.StartWorker("set app compat warning", func(ctx context.Context) error { + module.mgr.Go("set app compat warning", func(ctx *mgr.WorkerCtx) error { var changed bool func() { diff --git a/service/compat/selfcheck.go b/service/compat/selfcheck.go index 872053c8a..27efd4881 100644 --- a/service/compat/selfcheck.go +++ b/service/compat/selfcheck.go @@ -12,6 +12,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/resolver" @@ -130,7 +131,7 @@ func selfcheck(ctx context.Context) (issue *systemIssue, err error) { } // Start worker for the DNS lookup. 
- module.StartWorker("dns check lookup", func(_ context.Context) error { + module.mgr.Go("dns check lookup", func(_ *mgr.WorkerCtx) error { ips, err := net.LookupIP(randomSubdomain + DNSCheckInternalDomainScope) if err == nil && len(ips) > 0 { dnsCheckReturnedIP = ips[0] diff --git a/service/core/api.go b/service/core/api.go index bdacb6a14..2c25262f0 100644 --- a/service/core/api.go +++ b/service/core/api.go @@ -12,7 +12,6 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/base/utils/debug" @@ -54,7 +53,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: "debug/core", Read: api.PermitAnyone, - BelongsTo: module, DataFunc: debugInfo, Name: "Get Debug Information", Description: "Returns network debugging information, similar to debug/info, but with system status data.", @@ -71,7 +69,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: "app/auth", Read: api.PermitAnyone, - BelongsTo: module, StructFunc: authorizeApp, Name: "Request an authentication token with a given set of permissions. The user will be prompted to either authorize or deny the request. Used for external or third-party tool integrations.", Parameters: []api.Parameter{ @@ -103,7 +100,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: "app/profile", Read: api.PermitUser, - BelongsTo: module, StructFunc: getMyProfile, Name: "Get the ID of the calling profile", }); err != nil { @@ -118,7 +114,8 @@ func shutdown(_ *api.Request) (msg string, err error) { log.Warning("core: user requested shutdown via action") // Do not run in worker, as this would block itself here. - go modules.Shutdown() //nolint:errcheck + // TODO(vladimir): replace with something better + go ShutdownHook() //nolint:errcheck return "shutdown initiated", nil } diff --git a/service/core/base/global.go b/service/core/base/global.go index a28827a39..727fbd3de 100644 --- a/service/core/base/global.go +++ b/service/core/base/global.go @@ -8,9 +8,10 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/info" - "github.com/safing/portmaster/base/modules" ) +var ErrCleanExit = errors.New("clean exit requested") + // Default Values (changeable for testing). 
var ( DefaultAPIListenAddress = "127.0.0.1:817" @@ -25,10 +26,10 @@ func init() { flag.StringVar(&databaseDir, "db", "", "alias to --data (deprecated)") flag.BoolVar(&showVersion, "version", false, "show version and exit") - modules.SetGlobalPrepFn(globalPrep) + // modules.SetGlobalPrepFn(globalPrep) } -func globalPrep() error { +func GlobalPrep() error { // check if meta info is ok err := info.CheckVersion() if err != nil { @@ -38,7 +39,7 @@ func globalPrep() error { // print version if showVersion { fmt.Println(info.FullVersion()) - return modules.ErrCleanExit + return ErrCleanExit } // check data root diff --git a/service/core/base/logs.go b/service/core/base/logs.go index 91870e0e7..b78c75e17 100644 --- a/service/core/base/logs.go +++ b/service/core/base/logs.go @@ -1,7 +1,6 @@ package base import ( - "context" "errors" "os" "path/filepath" @@ -10,7 +9,7 @@ import ( "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" ) const ( @@ -20,12 +19,13 @@ const ( ) func registerLogCleaner() { - module.NewTask("log cleaner", logCleaner). - Repeat(24 * time.Hour). - Schedule(time.Now().Add(15 * time.Minute)) + module.mgr.Delay("log cleaner delay", 15*time.Minute, func(w *mgr.WorkerCtx) error { + module.mgr.Repeat("log cleaner", 24*time.Hour, logCleaner) + return nil + }) } -func logCleaner(_ context.Context, _ *modules.Task) error { +func logCleaner(_ *mgr.WorkerCtx) error { ageThreshold := time.Now().Add(-logTTL) return filepath.Walk( diff --git a/service/core/base/module.go b/service/core/base/module.go index 581e74b4b..165c3823c 100644 --- a/service/core/base/module.go +++ b/service/core/base/module.go @@ -1,31 +1,22 @@ package base import ( + "errors" + "sync/atomic" + _ "github.com/safing/portmaster/base/config" _ "github.com/safing/portmaster/base/metrics" - "github.com/safing/portmaster/base/modules" _ "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" ) -var module *modules.Module - -func init() { - module = modules.Register("base", nil, start, nil, "database", "config", "rng", "metrics") - - // For prettier subsystem graph, printed with --print-subsystem-graph - /* - subsystems.Register( - "base", - "Base", - "THE GROUND.", - baseModule, - "", - nil, - ) - */ +type Base struct { + mgr *mgr.Manager + instance instance } -func start() error { +func (b *Base) Start(m *mgr.Manager) error { + b.mgr = m startProfiling() if err := registerDatabases(); err != nil { @@ -36,3 +27,26 @@ func start() error { return nil } + +func (b *Base) Stop(m *mgr.Manager) error { + return nil +} + +var ( + module *Base + shimLoaded atomic.Bool +) + +// New returns a new Base module. 
+func New(instance instance) (*Base, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Base{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/service/core/base/profiling.go b/service/core/base/profiling.go index bae54645f..96d21eec4 100644 --- a/service/core/base/profiling.go +++ b/service/core/base/profiling.go @@ -1,11 +1,12 @@ package base import ( - "context" "flag" "fmt" "os" "runtime/pprof" + + "github.com/safing/portmaster/service/mgr" ) var cpuProfile string @@ -16,11 +17,11 @@ func init() { func startProfiling() { if cpuProfile != "" { - module.StartWorker("cpu profiler", cpuProfiler) + module.mgr.Go("cpu profiler", cpuProfiler) } } -func cpuProfiler(ctx context.Context) error { +func cpuProfiler(ctx *mgr.WorkerCtx) error { f, err := os.Create(cpuProfile) if err != nil { return fmt.Errorf("could not create CPU profile: %w", err) diff --git a/service/core/core.go b/service/core/core.go index fbeee9bc1..c1554e4bf 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -1,15 +1,16 @@ package core import ( + "errors" "flag" "fmt" + "sync/atomic" "time" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/metrics" - "github.com/safing/portmaster/base/modules" - "github.com/safing/portmaster/base/modules/subsystems" _ "github.com/safing/portmaster/service/broadcasts" + "github.com/safing/portmaster/service/mgr" _ "github.com/safing/portmaster/service/netenv" _ "github.com/safing/portmaster/service/netquery" _ "github.com/safing/portmaster/service/status" @@ -23,22 +24,40 @@ const ( eventRestart = "restart" ) -var ( - module *modules.Module +type Core struct { + instance instance - disableShutdownEvent bool -) + EventShutdown *mgr.EventMgr[struct{}] + EventRestart *mgr.EventMgr[struct{}] +} + +func (c *Core) Start(m *mgr.Manager) error { + c.EventShutdown = mgr.NewEventMgr[struct{}]("shutdown", m) + c.EventRestart = mgr.NewEventMgr[struct{}]("restart", m) + + if err := prep(); err != nil { + return err + } + + return start() +} + +func (c *Core) Stop(m *mgr.Manager) error { + return nil +} + +var disableShutdownEvent bool func init() { - module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "netquery", "interception", "compat", "broadcasts", "sync") - subsystems.Register( - "core", - "Core", - "Base Structure and System Integration", - module, - "config:core/", - nil, - ) + // module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "netquery", "interception", "compat", "broadcasts", "sync") + // subsystems.Register( + // "core", + // "Core", + // "Base Structure and System Integration", + // module, + // "config:core/", + // nil, + // ) flag.BoolVar( &disableShutdownEvent, @@ -47,12 +66,10 @@ func init() { "disable shutdown event to keep app and notifier open when core shuts down", ) - modules.SetGlobalShutdownFn(shutdownHook) + // modules.SetGlobalShutdownFn(shutdownHook) } func prep() error { - registerEvents() - // init config err := registerConfig() if err != nil { @@ -79,22 +96,37 @@ func start() error { return nil } -func registerEvents() { - module.RegisterEvent(eventShutdown, true) - module.RegisterEvent(eventRestart, true) -} - -func shutdownHook() { +func ShutdownHook() { // Notify everyone of the restart/shutdown. 
if !updates.IsRestarting() { // Only trigger shutdown event if not disabled. if !disableShutdownEvent { - module.TriggerEvent(eventShutdown, nil) + module.EventShutdown.Submit(struct{}{}) } } else { - module.TriggerEvent(eventRestart, nil) + module.EventRestart.Submit(struct{}{}) } // Wait a bit for the event to propagate. + // TODO(vladimir): is this necessary? time.Sleep(100 * time.Millisecond) } + +var ( + module *Core + shimLoaded atomic.Bool +) + +// New returns a new NetEnv module. +func New(instance instance) (*Core, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Core{ + instance: instance, + } + return module, nil +} + +type instance interface{} diff --git a/service/core/pmtesting/testing.go b/service/core/pmtesting/testing.go index 131410610..41eae4b60 100644 --- a/service/core/pmtesting/testing.go +++ b/service/core/pmtesting/testing.go @@ -26,7 +26,6 @@ import ( _ "github.com/safing/portmaster/base/database/storage/hashmap" "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/core/base" ) diff --git a/service/firewall/interception/interception_linux.go b/service/firewall/interception/interception_linux.go index 66ca5b7eb..24109f48b 100644 --- a/service/firewall/interception/interception_linux.go +++ b/service/firewall/interception/interception_linux.go @@ -1,12 +1,12 @@ package interception import ( - "context" "time" bandwidth "github.com/safing/portmaster/service/firewall/interception/ebpf/bandwidth" conn_listener "github.com/safing/portmaster/service/firewall/interception/ebpf/connection_listener" "github.com/safing/portmaster/service/firewall/interception/nfq" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/packet" ) @@ -20,13 +20,13 @@ func startInterception(packets chan packet.Packet) error { } // Start ebpf new connection listener. - module.StartServiceWorker("ebpf connection listener", 0, func(ctx context.Context) error { - return conn_listener.ConnectionListenerWorker(ctx, packets) + module.mgr.Go("ebpf connection listener", func(wc *mgr.WorkerCtx) error { + return conn_listener.ConnectionListenerWorker(wc.Ctx(), packets) }) // Start ebpf bandwidth stats monitor. - module.StartServiceWorker("ebpf bandwidth stats monitor", 0, func(ctx context.Context) error { - return bandwidth.BandwidthStatsWorker(ctx, 1*time.Second, BandwidthUpdates) + module.mgr.Go("ebpf bandwidth stats monitor", func(wc *mgr.WorkerCtx) error { + return bandwidth.BandwidthStatsWorker(wc.Ctx(), 1*time.Second, BandwidthUpdates) }) return nil diff --git a/service/firewall/interception/module.go b/service/firewall/interception/module.go index eaa013762..189d1bfcb 100644 --- a/service/firewall/interception/module.go +++ b/service/firewall/interception/module.go @@ -1,22 +1,31 @@ package interception import ( + "errors" "flag" "sync/atomic" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network/packet" ) type Interception struct { + mgr *mgr.Manager instance instance } -var ( - module *Interception - shimLoaded atomic.Bool +func (i *Interception) Start(m *mgr.Manager) error { + i.mgr = m + return start() +} + +func (i *Interception) Stop(m *mgr.Manager) error { + return stop() +} - // Packets is a stream of interception network packest. 
+var ( + // Packets is a stream of interception network packets. Packets = make(chan packet.Packet, 1000) // BandwidthUpdates is a stream of bandwidth usage update for connections. @@ -31,10 +40,6 @@ func init() { // module = modules.Register("interception", prep, start, stop, "base", "updates", "network", "notifications", "profiles") } -func prep() error { - return nil -} - // Start starts the interception. func start() error { if disableInterception { @@ -67,4 +72,21 @@ func stop() error { return stopInterception() } +var ( + module *Interception + shimLoaded atomic.Bool +) + +// New returns a new Interception module. +func New(instance instance) (*Interception, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &Interception{ + instance: instance, + } + return module, nil +} + type instance interface{} diff --git a/service/firewall/interception/nfqueue_linux.go b/service/firewall/interception/nfqueue_linux.go index e4c136d9d..cbaad7cce 100644 --- a/service/firewall/interception/nfqueue_linux.go +++ b/service/firewall/interception/nfqueue_linux.go @@ -1,7 +1,6 @@ package interception import ( - "context" "flag" "fmt" "sort" @@ -12,6 +11,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/firewall/interception/nfq" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/packet" ) @@ -290,7 +290,7 @@ func StartNfqueueInterception(packets chan<- packet.Packet) (err error) { in6Queue = &disabledNfQueue{} } - module.StartServiceWorker("nfqueue packet handler", 0, func(_ context.Context) error { + module.mgr.Go("nfqueue packet handler", func(_ *mgr.WorkerCtx) error { return handleInterception(packets) }) return nil diff --git a/service/firewall/module.go b/service/firewall/module.go index 71c9a8307..7439ce594 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -1,7 +1,6 @@ package firewall import ( - "context" "flag" "fmt" "path/filepath" @@ -12,6 +11,7 @@ import ( "github.com/safing/portmaster/base/log" _ "github.com/safing/portmaster/service/core" "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/netquery" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/profile" "github.com/safing/portmaster/spn/access" @@ -81,51 +81,23 @@ func prep() error { // Reset connections when spn is connected // connect and disconnecting is triggered on config change event but connecting takеs more time - err = module.RegisterEventHook( - "captain", - captain.SPNConnectedEvent, - "reset connection verdicts on SPN connect", - func(ctx context.Context, _ interface{}) error { - resetAllConnectionVerdicts() - return nil - }, - ) - if err != nil { - log.Errorf("filter: failed to register event hook: %s", err) - } + module.instance.Captain().EventSPNConnected.AddCallback("reset connection verdicts on SPN connect", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + resetAllConnectionVerdicts() + return false, err + }) // Reset connections when account is updated. // This will not change verdicts, but will update the feature flags on connections. 
- err = module.RegisterEventHook( - "access", - access.AccountUpdateEvent, - "update connection feature flags after account update", - func(ctx context.Context, _ interface{}) error { - resetAllConnectionVerdicts() - return nil - }, - ) - if err != nil { - log.Errorf("filter: failed to register event hook: %s", err) - } + module.instance.Access().EventAccountUpdate.AddCallback("update connection feature flags after account update", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + resetAllConnectionVerdicts() + return false, err + }) - err = module.RegisterEventHook( - "network", - network.ConnectionReattributedEvent, - "reset verdict of re-attributed connection", - func(ctx context.Context, eventData interface{}) error { - // Expected event data: connection ID. - connID, ok := eventData.(string) - if !ok { - return fmt.Errorf("event data is not a string: %v", eventData) - } - resetSingleConnectionVerdict(connID) - return nil - }, - ) - if err != nil { - log.Errorf("filter: failed to register event hook: %s", err) - } + module.instance.Network().EventConnectionReattributed.AddCallback("reset connection verdicts after connection re-attribution", func(wc *mgr.WorkerCtx, connID string) (cancel bool, err error) { + // Expected event data: connection ID. + resetSingleConnectionVerdict(connID) + return false, err + }) if err := registerConfig(); err != nil { return err @@ -169,4 +141,8 @@ func New(instance instance) (*Filter, error) { type instance interface { Config() *config.Config Profile() *profile.ProfileModule + Captain() *captain.Captain + Access() *access.Access + Network() *network.Network + NetQuery() *netquery.NetQuery } diff --git a/service/firewall/packet_handler.go b/service/firewall/packet_handler.go index 9e48d2468..a290182f8 100644 --- a/service/firewall/packet_handler.go +++ b/service/firewall/packet_handler.go @@ -19,7 +19,6 @@ import ( "github.com/safing/portmaster/service/firewall/interception" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" - "github.com/safing/portmaster/service/netquery" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/network/packet" @@ -794,8 +793,8 @@ func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) { } // Update bandwidth in the netquery module. 
- if netquery.DefaultModule != nil && conn.BandwidthEnabled { - if err := netquery.DefaultModule.Store.UpdateBandwidth( + if module.instance.NetQuery() != nil && conn.BandwidthEnabled { + if err := module.instance.NetQuery().Store.UpdateBandwidth( ctx, conn.HistoryEnabled, fmt.Sprintf("%s/%s", conn.ProcessContext.Source, conn.ProcessContext.Profile), diff --git a/service/instance.go b/service/instance.go index 5fd89a63b..78eb8c79e 100644 --- a/service/instance.go +++ b/service/instance.go @@ -11,7 +11,12 @@ import ( "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/broadcasts" "github.com/safing/portmaster/service/compat" + "github.com/safing/portmaster/service/core" + "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/firewall" + "github.com/safing/portmaster/service/firewall/interception" + "github.com/safing/portmaster/service/intel/customlists" + "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/nameserver" "github.com/safing/portmaster/service/netenv" @@ -48,6 +53,7 @@ type Instance struct { runtime *runtime.Runtime notifications *notifications.Notifications rng *rng.Rng + base *base.Base access *access.Access cabin *cabin.Cabin @@ -60,20 +66,24 @@ type Instance struct { sluice *sluice.SluiceModule terminal *terminal.TerminalModule - updates *updates.Updates - ui *ui.UI - profile *profile.ProfileModule - filter *firewall.Filter - netenv *netenv.NetEnv - status *status.Status - broadcasts *broadcasts.Broadcasts - compat *compat.Compat - nameserver *nameserver.NameServer - netquery *netquery.NetQuery - network *network.Network - process *process.ProcessModule - resolver *resolver.ResolverModule - sync *sync.Sync + updates *updates.Updates + ui *ui.UI + profile *profile.ProfileModule + filter *firewall.Filter + interception *interception.Interception + customlist *customlists.CustomList + geoip *geoip.GeoIP + netenv *netenv.NetEnv + status *status.Status + broadcasts *broadcasts.Broadcasts + compat *compat.Compat + nameserver *nameserver.NameServer + netquery *netquery.NetQuery + network *network.Network + process *process.ProcessModule + resolver *resolver.ResolverModule + sync *sync.Sync + core *core.Core } // New returns a new portmaster service instance. 
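Nearly every module converted in this patch ends up with the same shape: a struct that keeps its *mgr.Manager and an instance interface, Start/Stop methods, typed events via mgr.EventMgr, and a New constructor that enforces a single instance. For orientation, a minimal sketch of that shape follows, written against only the mgr API that appears in this patch (mgr.Manager, mgr.WorkerCtx, mgr.EventMgr); the Example module, its event name, and its no-op worker are hypothetical placeholders, not code from the patch:

package example

import (
	"errors"
	"sync/atomic"

	"github.com/safing/portmaster/service/mgr"
)

// Example follows the module shape used throughout this patch.
type Example struct {
	mgr      *mgr.Manager
	instance instance

	// Typed events replace the old string-keyed module events.
	// Other modules subscribe via EventSomethingHappened.AddCallback(...).
	EventSomethingHappened *mgr.EventMgr[struct{}]
}

// Start stores the manager and runs workers through it.
func (e *Example) Start(m *mgr.Manager) error {
	e.mgr = m
	e.EventSomethingHappened = mgr.NewEventMgr[struct{}]("something happened", m)

	// Workers are started via the manager instead of module.StartWorker.
	e.mgr.Go("example worker", func(_ *mgr.WorkerCtx) error {
		return nil
	})
	return nil
}

// Stop tears the module down; the manager cancels its workers.
func (e *Example) Stop(m *mgr.Manager) error {
	return nil
}

var (
	module     *Example
	shimLoaded atomic.Bool
)

// New creates the single Example module instance, mirroring the other constructors in this patch.
func New(instance instance) (*Example, error) {
	if !shimLoaded.CompareAndSwap(false, true) {
		return nil, errors.New("only one instance allowed")
	}
	module = &Example{instance: instance}
	return module, nil
}

// instance declares only the other modules this one needs.
type instance interface{}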
@@ -110,6 +120,10 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create rng module: %w", err) } + instance.base, err = base.New(instance) + if err != nil { + return nil, fmt.Errorf("create base module: %w", err) + } // SPN modules instance.access, err = access.New(instance) @@ -120,7 +134,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create cabin module: %w", err) } - instance.captain, err = captain.New(instance) + instance.captain, err = captain.New(instance, svcCfg.ShutdownFunc) if err != nil { return nil, fmt.Errorf("create captain module: %w", err) } @@ -170,6 +184,18 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create filter module: %w", err) } + instance.interception, err = interception.New(instance) + if err != nil { + return nil, fmt.Errorf("create interception module: %w", err) + } + instance.customlist, err = customlists.New(instance) + if err != nil { + return nil, fmt.Errorf("create customlist module: %w", err) + } + instance.geoip, err = geoip.New(instance) + if err != nil { + return nil, fmt.Errorf("create customlist module: %w", err) + } instance.netenv, err = netenv.New(instance) if err != nil { return nil, fmt.Errorf("create netenv module: %w", err) @@ -210,6 +236,10 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create sync module: %w", err) } + instance.core, err = core.New(instance) + if err != nil { + return nil, fmt.Errorf("create core module: %w", err) + } // Add all modules to instance group. instance.Group = mgr.NewGroup( @@ -219,6 +249,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.runtime, instance.notifications, instance.rng, + instance.base, instance.access, instance.cabin, @@ -235,6 +266,9 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.ui, instance.profile, instance.filter, + instance.interception, + instance.customlist, + instance.geoip, instance.netenv, instance.status, instance.broadcasts, @@ -245,11 +279,21 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.process, instance.resolver, instance.sync, + instance.core, ) + // FIXME: call this before to trigger shutdown/restart event + // core.ShutdownHook() + return instance, nil } +func (i *Instance) SetSleep(enabled bool) { + i.metrics.SetSleep(enabled) + i.network.SetSleep(enabled) + i.captain.SetSleep(enabled) +} + // Version returns the version. func (i *Instance) Version() string { return i.version @@ -280,6 +324,11 @@ func (i *Instance) Rng() *rng.Rng { return i.rng } +// Base returns the base module. +func (i *Instance) Base() *base.Base { + return i.base +} + // Access returns the access module. func (i *Instance) Access() *access.Access { return i.access @@ -330,6 +379,11 @@ func (i *Instance) Terminal() *terminal.TerminalModule { return i.terminal } +// Updates returns the updates module. +func (i *Instance) Updates() *updates.Updates { + return i.updates +} + // UI returns the ui module. func (i *Instance) UI() *ui.UI { return i.ui @@ -345,11 +399,21 @@ func (i *Instance) Profile() *profile.ProfileModule { return i.profile } -// Profile returns the profile module. +// Firewall returns the firewall module. func (i *Instance) Firewall() *firewall.Filter { return i.filter } +// Interception returns the interception module. 
+func (i *Instance) Interception() *interception.Interception { + return i.interception +} + +// CustomList returns the customlist module. +func (i *Instance) CustomList() *customlists.CustomList { + return i.customlist +} + // NetEnv returns the netenv module. func (i *Instance) NetEnv() *netenv.NetEnv { return i.netenv @@ -399,3 +463,14 @@ func (i *Instance) Resolver() *resolver.ResolverModule { func (i *Instance) Sync() *sync.Sync { return i.sync } + +// Core returns the core module. +func (i *Instance) Core() *core.Core { + return i.core +} + +// Events +// SPN connected +func (i *Instance) GetEventSPNConnected() *mgr.EventMgr[struct{}] { + return i.captain.EventSPNConnected +} diff --git a/service/intel/customlists/lists.go b/service/intel/customlists/lists.go index cf807248a..66935991e 100644 --- a/service/intel/customlists/lists.go +++ b/service/intel/customlists/lists.go @@ -12,6 +12,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network/netutils" ) @@ -79,7 +80,12 @@ func parseFile(filePath string) error { file, err := os.Open(filePath) if err != nil { log.Warningf("intel/customlists: failed to parse file %s", err) - module.Warning(parseWarningNotificationID, "Failed to open custom filter list", err.Error()) + module.States.Add(mgr.State{ + ID: parseWarningNotificationID, + Name: "Failed to open custom filter list", + Message: err.Error(), + Type: mgr.StateTypeWarning, + }) return err } defer func() { _ = file.Close() }() @@ -107,11 +113,15 @@ func parseFile(filePath string) error { if invalidLinesRation > rationForInvalidLinesUntilWarning { log.Warning("intel/customlists: Too many invalid lines") - module.Warning(zeroIPNotificationID, "Custom filter list has many invalid lines", - fmt.Sprintf(`%d out of %d lines are invalid. - Check if you are using the correct file format and if the path to the custom filter list is correct.`, invalidLinesCount, allLinesCount)) + module.States.Add(mgr.State{ + ID: zeroIPNotificationID, + Name: "Custom filter list has many invalid lines", + Message: fmt.Sprintf(`%d out of %d lines are invalid. 
+ Check if you are using the correct file format and if the path to the custom filter list is correct.`, invalidLinesCount, allLinesCount), + Type: mgr.StateTypeWarning, + }) } else { - module.Resolve(zeroIPNotificationID) + module.States.Remove(zeroIPNotificationID) } allEntriesCount := len(domainsFilterList) + len(ipAddressesFilterList) + len(autonomousSystemsFilterList) + len(countryCodesFilterList) @@ -130,7 +140,7 @@ func parseFile(filePath string) error { len(autonomousSystemsFilterList), len(countryCodesFilterList))) - module.Resolve(parseWarningNotificationID) + module.States.Remove(parseWarningNotificationID) return nil } diff --git a/service/intel/customlists/module.go b/service/intel/customlists/module.go index 988617037..c7c59edf9 100644 --- a/service/intel/customlists/module.go +++ b/service/intel/customlists/module.go @@ -1,27 +1,42 @@ package customlists import ( - "context" "errors" "net" "os" "regexp" "strings" "sync" + "sync/atomic" "time" "golang.org/x/net/publicsuffix" "github.com/safing/portmaster/base/api" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/service/mgr" ) -var module *modules.Module +type CustomList struct { + mgr *mgr.Manager + instance instance -const ( - configModuleName = "config" - configChangeEvent = "config change" -) + States *mgr.StateMgr +} + +func (cl *CustomList) Start(m *mgr.Manager) error { + cl.mgr = m + cl.States = mgr.NewStateMgr(m) + + if err := prep(); err != nil { + return err + } + return start() +} + +func (cl *CustomList) Stop(m *mgr.Manager) error { + return nil +} // Helper variables for parsing the input file. var ( @@ -34,17 +49,12 @@ var ( filterListFileModifiedTime time.Time filterListLock sync.RWMutex - parserTask *modules.Task // ErrNotConfigured is returned when updating the custom filter list, but it // is not configured. ErrNotConfigured = errors.New("custom filter list not configured") ) -func init() { - module = modules.Register("customlists", prep, start, nil, "base") -} - func prep() error { initFilterLists() @@ -56,9 +66,8 @@ func prep() error { // Register api endpoint for updating the filter list. if err := api.RegisterEndpoint(api.Endpoint{ - Path: "customlists/update", - Write: api.PermitUser, - BelongsTo: module, + Path: "customlists/update", + Write: api.PermitUser, ActionFunc: func(ar *api.Request) (msg string, err error) { errCheck := checkAndUpdateFilterList() if errCheck != nil { @@ -77,25 +86,21 @@ func prep() error { func start() error { // Register to hook to update after config change. - if err := module.RegisterEventHook( - configModuleName, - configChangeEvent, + module.instance.Config().EventConfigChange.AddCallback( "update custom filter list", - func(ctx context.Context, obj interface{}) error { + func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { if err := checkAndUpdateFilterList(); !errors.Is(err, ErrNotConfigured) { - return err + return false, err } - return nil + return false, nil }, - ); err != nil { - return err - } + ) // Create parser task and enqueue for execution. "checkAndUpdateFilterList" will schedule the next execution. 
- parserTask = module.NewTask("intel/customlists:file-update-check", func(context.Context, *modules.Task) error { + module.mgr.Repeat("intel/customlists:file-update-check", 20*time.Second, func(_ *mgr.WorkerCtx) error { _ = checkAndUpdateFilterList() return nil - }).Schedule(time.Now().Add(20 * time.Second)) + }) return nil } @@ -111,7 +116,8 @@ func checkAndUpdateFilterList() error { } // Schedule next update check - parserTask.Schedule(time.Now().Add(1 * time.Minute)) + // TODO(vladimir): The task is set to repeat every few seconds. Is there another way to make it better? + // parserTask.Schedule(time.Now().Add(1 * time.Minute)) // Try to get file info modifiedTime := time.Now() @@ -205,3 +211,24 @@ func splitDomain(domain string) []string { } return domains } + +var ( + module *CustomList + shimLoaded atomic.Bool +) + +// New returns a new CustomList module. +func New(instance instance) (*CustomList, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &CustomList{ + instance: instance, + } + return module, nil +} + +type instance interface { + Config() *config.Config +} diff --git a/service/intel/filterlists/module.go b/service/intel/filterlists/module.go index 0fa6aadf8..376402f90 100644 --- a/service/intel/filterlists/module.go +++ b/service/intel/filterlists/module.go @@ -1,25 +1,45 @@ package filterlists import ( - "context" + "errors" "fmt" + "sync/atomic" "github.com/tevino/abool" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/updates" ) -var module *modules.Module - const ( filterlistsDisabled = "filterlists:disabled" filterlistsUpdateFailed = "filterlists:update-failed" filterlistsStaleDataSurvived = "filterlists:staledata" ) +type FilterLists struct { + mgr *mgr.Manager + instance instance + + States *mgr.StateMgr +} + +func (fl *FilterLists) Start(m *mgr.Manager) error { + fl.mgr = m + fl.States = mgr.NewStateMgr(m) + + if err := prep(); err != nil { + return err + } + return start() +} + +func (fl *FilterLists) Stop(m *mgr.Manager) error { + return stop() +} + // booleans mainly used to decouple the module // during testing.
var ( @@ -30,44 +50,31 @@ var ( func init() { ignoreNetEnvEvents.Set() - module = modules.Register("filterlists", prep, start, stop, "base", "updates") + // module = modules.Register("filterlists", prep, start, stop, "base", "updates") } func prep() error { - if err := module.RegisterEventHook( - updates.ModuleName, - updates.ResourceUpdateEvent, - "Check for blocklist updates", - func(ctx context.Context, _ interface{}) error { + module.instance.Updates().EventResourcesUpdated.AddCallback("Check for blocklist updates", + func(wc *mgr.WorkerCtx, s struct{}) (bool, error) { if ignoreUpdateEvents.IsSet() { - return nil + return false, nil } - return tryListUpdate(ctx) - }, - ); err != nil { - return fmt.Errorf("failed to register resource update event handler: %w", err) - } + return false, tryListUpdate(wc.Ctx()) + }) - if err := module.RegisterEventHook( - netenv.ModuleName, - netenv.OnlineStatusChangedEvent, - "Check for blocklist updates", - func(ctx context.Context, _ interface{}) error { + module.instance.NetEnv().EventOnlineStatusChange.AddCallback("Check for blocklist updates", + func(wc *mgr.WorkerCtx, s netenv.OnlineStatus) (bool, error) { if ignoreNetEnvEvents.IsSet() { - return nil + return false, nil } - // Nothing to do if we went offline. - if !netenv.Online() { - return nil + if s == netenv.StatusOffline { + return false, nil } - return tryListUpdate(ctx) - }, - ); err != nil { - return fmt.Errorf("failed to register online status changed event handler: %w", err) - } + return false, tryListUpdate(wc.Ctx()) + }) return nil } @@ -102,9 +109,32 @@ func stop() error { } func warnAboutDisabledFilterLists() { - module.Warning( - filterlistsDisabled, - "Filter Lists Are Initializing", - "Filter lists are being downloaded and set up in the background. They will be activated as configured when finished.", - ) + module.States.Add(mgr.State{ + ID: filterlistsDisabled, + Name: "Filter Lists Are Initializing", + Message: "Filter lists are being downloaded and set up in the background. They will be activated as configured when finished.", + Type: mgr.StateTypeWarning, + }) +} + +var ( + module *FilterLists + shimLoaded atomic.Bool +) + +// New returns a new FilterLists module. +func New(instance instance) (*FilterLists, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &FilterLists{ + instance: instance, + } + return module, nil +} + +type instance interface { + Updates() *updates.Updates + NetEnv() *netenv.NetEnv } diff --git a/service/intel/filterlists/updater.go b/service/intel/filterlists/updater.go index cbe9fc3f8..8d3b19237 100644 --- a/service/intel/filterlists/updater.go +++ b/service/intel/filterlists/updater.go @@ -13,8 +13,8 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/service/mgr" ) var updateInProgress = abool.New() @@ -25,19 +25,27 @@ func tryListUpdate(ctx context.Context) error { err := performUpdate(ctx) if err != nil { // Check if we are shutting down. - if module.IsStopping() { - return nil - } + // TODO(vladimir): Do we need stopping detection? + // if module.IsStopping() { + // return nil + // } // Check if the module already has a failure status set. If not, set a // generic one with the returned error. 
- failureStatus, _, _ := module.FailureStatus() - if failureStatus < modules.FailureWarning { - module.Warning( - filterlistsUpdateFailed, - "Filter Lists Update Failed", - fmt.Sprintf("The Portmaster failed to process a filter lists update. Filtering capabilities are currently either impaired or not available at all. Error: %s", err.Error()), - ) + + hasWarningState := false + for _, state := range module.States.Export().States { + if state.Type == mgr.StateTypeWarning { + hasWarningState = true + } + } + if !hasWarningState { + module.States.Add(mgr.State{ + ID: filterlistsUpdateFailed, + Name: "Filter Lists Update Failed", + Message: fmt.Sprintf("The Portmaster failed to process a filter lists update. Filtering capabilities are currently either impaired or not available at all. Error: %s", err.Error()), + Type: mgr.StateTypeWarning, + }) } return err @@ -122,15 +130,16 @@ func performUpdate(ctx context.Context) error { // been updated now. Once we are done, start a worker // for that purpose. if cleanupRequired { - if err := module.RunWorker("filterlists:cleanup", removeAllObsoleteFilterEntries); err != nil { + if err := module.mgr.Do("filterlists:cleanup", removeAllObsoleteFilterEntries); err != nil { // if we failed to remove all stale cache entries // we abort now WITHOUT updating the database version. This means // we'll try again during the next update. - module.Warning( - filterlistsStaleDataSurvived, - "Filter Lists May Overblock", - fmt.Sprintf("The Portmaster failed to delete outdated filter list data. Filtering capabilities are fully available, but overblocking may occur. Error: %s", err.Error()), //nolint:misspell // overblocking != overclocking - ) + module.States.Add(mgr.State{ + ID: filterlistsStaleDataSurvived, + Name: "Filter Lists May Overblock", + Message: fmt.Sprintf("The Portmaster failed to delete outdated filter list data. Filtering capabilities are fully available, but overblocking may occur. Error: %s", err.Error()), //nolint:misspell // overblocking != overclocking + Type: mgr.StateTypeWarning, + }) return fmt.Errorf("failed to cleanup stale cache records: %w", err) } } @@ -144,13 +153,13 @@ func performUpdate(ctx context.Context) error { } // The list update succeeded, resolve any states. 
- module.Resolve("") + module.States.Clear() return nil } -func removeAllObsoleteFilterEntries(ctx context.Context) error { +func removeAllObsoleteFilterEntries(wc *mgr.WorkerCtx) error { log.Debugf("intel/filterlists: cleanup task started, removing obsolete filter list entries ...") - n, err := cache.Purge(ctx, query.New(filterListKeyPrefix).Where( + n, err := cache.Purge(wc.Ctx(), query.New(filterListKeyPrefix).Where( // TODO(ppacher): remember the timestamp we started the last update // and use that rather than "one hour ago" query.Where("UpdatedAt", query.LessThan, time.Now().Add(-time.Hour).Unix()), diff --git a/service/intel/geoip/database.go b/service/intel/geoip/database.go index 3101f7dc0..6aee3d944 100644 --- a/service/intel/geoip/database.go +++ b/service/intel/geoip/database.go @@ -1,7 +1,6 @@ package geoip import ( - "context" "fmt" "sync" "time" @@ -10,6 +9,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" ) @@ -148,11 +148,11 @@ func (upd *updateWorker) triggerUpdate() { func (upd *updateWorker) start() { upd.once.Do(func() { - module.StartServiceWorker("geoip-updater", time.Second*10, upd.run) + module.mgr.Delay("geoip-updater", time.Second*10, upd.run) }) } -func (upd *updateWorker) run(ctx context.Context) error { +func (upd *updateWorker) run(ctx *mgr.WorkerCtx) error { for { if upd.v4.NeedsUpdate() { if v4, err := getGeoIPDB(v4MMDBResource); err == nil { diff --git a/service/intel/geoip/module.go b/service/intel/geoip/module.go index 7141d476e..326352ecc 100644 --- a/service/intel/geoip/module.go +++ b/service/intel/geoip/module.go @@ -1,20 +1,21 @@ package geoip import ( - "context" + "errors" + "sync/atomic" "github.com/safing/portmaster/base/api" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" ) -var module *modules.Module - -func init() { - module = modules.Register("geoip", prep, nil, nil, "base", "updates") +type GeoIP struct { + mgr *mgr.Manager + instance instance } -func prep() error { +func (g *GeoIP) Start(m *mgr.Manager) error { + g.mgr = m if err := api.RegisterEndpoint(api.Endpoint{ Path: "intel/geoip/countries", Read: api.PermitUser, @@ -28,13 +29,36 @@ func prep() error { return err } - return module.RegisterEventHook( - updates.ModuleName, - updates.ResourceUpdateEvent, + module.instance.Updates().EventResourcesUpdated.AddCallback( "Check for GeoIP database updates", - func(c context.Context, i interface{}) error { + func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { worker.triggerUpdate() - return nil - }, - ) + return false, nil + }) + return nil +} + +func (g *GeoIP) Stop(m *mgr.Manager) error { + return nil +} + +var ( + module *GeoIP + shimLoaded atomic.Bool +) + +// New returns a new GeoIP module. 
+func New(instance instance) (*GeoIP, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + module = &GeoIP{ + instance: instance, + } + return module, nil +} + +type instance interface { + Updates() *updates.Updates } diff --git a/service/intel/module.go b/service/intel/module.go index 87f872b51..41a7386aa 100644 --- a/service/intel/module.go +++ b/service/intel/module.go @@ -1,13 +1,12 @@ package intel -import ( - "github.com/safing/portmaster/base/modules" - _ "github.com/safing/portmaster/service/intel/customlists" -) +// import ( +// _ "github.com/safing/portmaster/service/intel/customlists" +// ) // Module of this package. Export needed for testing of the endpoints package. -var Module *modules.Module +// var Module *modules.Module -func init() { - Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists", "customlists") -} +// func init() { +// Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists", "customlists") +// } diff --git a/service/mgr/sleepyticker.go b/service/mgr/sleepyticker.go new file mode 100644 index 000000000..075912a16 --- /dev/null +++ b/service/mgr/sleepyticker.go @@ -0,0 +1,58 @@ +package mgr + +import "time" + +// SleepyTicker is wrapper over time.Ticker that respects the sleep mode of the module. +type SleepyTicker struct { + ticker time.Ticker + normalDuration time.Duration + sleepDuration time.Duration + sleepMode bool + + sleepChannel chan time.Time +} + +// NewSleepyTicker returns a new SleepyTicker. This is a wrapper of the standard time.Ticker but it respects modules.Module sleep mode. Check https://pkg.go.dev/time#Ticker. +// If sleepDuration is set to 0 ticker will not tick during sleep. +func NewSleepyTicker(normalDuration time.Duration, sleepDuration time.Duration) *SleepyTicker { + st := &SleepyTicker{ + ticker: *time.NewTicker(normalDuration), + normalDuration: normalDuration, + sleepDuration: sleepDuration, + sleepMode: false, + } + + return st +} + +// Wait waits until the module is not in sleep mode and returns time.Ticker.C channel. +func (st *SleepyTicker) Wait() <-chan time.Time { + if st.sleepMode && st.sleepDuration == 0 { + return st.sleepChannel + } + return st.ticker.C +} + +// Stop turns off a ticker. After Stop, no more ticks will be sent. Stop does not close the channel, to prevent a concurrent goroutine reading from the channel from seeing an erroneous "tick". +func (st *SleepyTicker) Stop() { + st.ticker.Stop() +} + +// SetSleep sets the sleep mode of the ticker. If enabled is true, the ticker will tick with sleepDuration. If enabled is false, the ticker will tick with normalDuration. +func (st *SleepyTicker) SetSleep(enabled bool) { + st.sleepMode = enabled + if enabled { + if st.sleepDuration > 0 { + st.ticker.Reset(st.sleepDuration) + } else { + // Next call to Wait will wait until SetSleep is called with enabled == false + st.sleepChannel = make(chan time.Time) + } + } else { + st.ticker.Reset(st.normalDuration) + if st.sleepDuration > 0 { + // Notify that we are not sleeping anymore. 
+ close(st.sleepChannel) + } + } +} diff --git a/service/nameserver/module.go b/service/nameserver/module.go index c8d0afd08..1107c85d3 100644 --- a/service/nameserver/module.go +++ b/service/nameserver/module.go @@ -22,10 +22,13 @@ import ( type NameServer struct { mgr *mgr.Manager instance instance + + States *mgr.StateMgr } func (ns *NameServer) Start(m *mgr.Manager) error { ns.mgr = m + ns.States = mgr.NewStateMgr(m) if err := prep(); err != nil { return err } @@ -154,7 +157,7 @@ func startListener(ip net.IP, port uint16, first bool) { // Resolve generic listener error, if primary listener. if first { - module.Resolve(eventIDListenerFailed) + module.States.Remove(eventIDListenerFailed) } // Start listening. @@ -162,7 +165,7 @@ func startListener(ip net.IP, port uint16, first bool) { err := dnsServer.ListenAndServe() if err != nil { // Stop worker without error if we are shutting down. - if module.IsStopping() { + if module.mgr.IsDone() { return nil } log.Warningf("nameserver: failed to listen on %s: %s", dnsServer.Addr, err) @@ -173,7 +176,7 @@ func startListener(ip net.IP, port uint16, first bool) { } func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) { - var n *notifications.Notification + // var n *notifications.Notification // Create suffix for secondary listener var secondaryEventIDSuffix string @@ -202,7 +205,7 @@ func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) } // Notify user about conflicting service. - n = notifications.Notify(¬ifications.Notification{ + _ = notifications.Notify(¬ifications.Notification{ EventID: eventIDConflictingService + secondaryEventIDSuffix, Type: notifications.Error, Title: "Conflicting DNS Software", @@ -219,7 +222,7 @@ func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) }) } else { // If no conflict is found, report the error directly. - n = notifications.Notify(¬ifications.Notification{ + _ = notifications.Notify(¬ifications.Notification{ EventID: eventIDListenerFailed + secondaryEventIDSuffix, Type: notifications.Error, Title: "Secure DNS Error", @@ -232,9 +235,10 @@ func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) } // Attach error to module, if primary listener. - if primaryListener { - n.AttachToModule(module) - } + // TODO(vladimir): is this needed? 
+ // if primaryListener { + // n.AttachToModule(module) + // } } func stop() error { diff --git a/service/nameserver/nameserver.go b/service/nameserver/nameserver.go index 903541f5f..c699cd993 100644 --- a/service/nameserver/nameserver.go +++ b/service/nameserver/nameserver.go @@ -12,6 +12,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/firewall" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/nameserver/nsutil" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network" @@ -24,8 +25,8 @@ var hostname string const internalError = "internal error: " func handleRequestAsWorker(w dns.ResponseWriter, query *dns.Msg) { - err := module.RunWorker("handle dns request", func(ctx context.Context) error { - return handleRequest(ctx, w, query) + err := module.mgr.Do("handle dns request", func(wc *mgr.WorkerCtx) error { + return handleRequest(wc.Ctx(), w, query) }) if err != nil { log.Warningf("nameserver: failed to handle dns request: %s", err) diff --git a/service/netenv/api.go b/service/netenv/api.go index e5dca1870..358bee3fa 100644 --- a/service/netenv/api.go +++ b/service/netenv/api.go @@ -8,9 +8,8 @@ import ( func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ - Path: "network/gateways", - Read: api.PermitUser, - BelongsTo: module, + Path: "network/gateways", + Read: api.PermitUser, StructFunc: func(ar *api.Request) (i interface{}, err error) { return Gateways(), nil }, @@ -21,9 +20,8 @@ func registerAPIEndpoints() error { } if err := api.RegisterEndpoint(api.Endpoint{ - Path: "network/nameservers", - Read: api.PermitUser, - BelongsTo: module, + Path: "network/nameservers", + Read: api.PermitUser, StructFunc: func(ar *api.Request) (i interface{}, err error) { return Nameservers(), nil }, @@ -34,9 +32,8 @@ func registerAPIEndpoints() error { } if err := api.RegisterEndpoint(api.Endpoint{ - Path: "network/location", - Read: api.PermitUser, - BelongsTo: module, + Path: "network/location", + Read: api.PermitUser, StructFunc: func(ar *api.Request) (i interface{}, err error) { locs, ok := GetInternetLocation() if ok { @@ -51,9 +48,8 @@ func registerAPIEndpoints() error { } if err := api.RegisterEndpoint(api.Endpoint{ - Path: "network/location/traceroute", - Read: api.PermitUser, - BelongsTo: module, + Path: "network/location/traceroute", + Read: api.PermitUser, StructFunc: func(ar *api.Request) (i interface{}, err error) { return getLocationFromTraceroute(&DeviceLocations{}) }, diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go index 2a4a2d008..831b72bee 100644 --- a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -103,7 +103,6 @@ func (nq *NetQuery) prepare() error { MimeType: "application/json", Read: api.PermitUser, // Needs read+write as the query is sent using POST data. Write: api.PermitUser, // Needs read+write as the query is sent using POST data. - BelongsTo: m.Module, HandlerFunc: servertiming.Middleware(queryHander, nil).ServeHTTP, }); err != nil { return fmt.Errorf("failed to register API endpoint: %w", err) @@ -116,7 +115,6 @@ func (nq *NetQuery) prepare() error { MimeType: "application/json", Read: api.PermitUser, // Needs read+write as the query is sent using POST data. Write: api.PermitUser, // Needs read+write as the query is sent using POST data. 
- BelongsTo: m.Module, HandlerFunc: servertiming.Middleware(batchHander, nil).ServeHTTP, }); err != nil { return fmt.Errorf("failed to register API endpoint: %w", err) @@ -128,7 +126,6 @@ func (nq *NetQuery) prepare() error { Path: "netquery/charts/connection-active", MimeType: "application/json", Write: api.PermitUser, - BelongsTo: m.Module, HandlerFunc: servertiming.Middleware(chartHandler, nil).ServeHTTP, }); err != nil { return fmt.Errorf("failed to register API endpoint: %w", err) @@ -139,7 +136,6 @@ func (nq *NetQuery) prepare() error { Path: "netquery/charts/bandwidth", MimeType: "application/json", Write: api.PermitUser, - BelongsTo: m.Module, HandlerFunc: bwChartHandler.ServeHTTP, Name: "Bandwidth Chart", Description: "Query the in-memory sqlite connection database and return a chart of bytes sent/received.", @@ -153,7 +149,6 @@ func (nq *NetQuery) prepare() error { Path: "netquery/history/clear", MimeType: "application/json", Write: api.PermitUser, - BelongsTo: m.Module, ActionFunc: func(ar *api.Request) (msg string, err error) { var body struct { ProfileIDs []string `json:"profileIDs"` @@ -188,10 +183,9 @@ func (nq *NetQuery) prepare() error { } if err := api.RegisterEndpoint(api.Endpoint{ - Name: "Apply connection history retention threshold", - Path: "netquery/history/cleanup", - Write: api.PermitUser, - BelongsTo: m.Module, + Name: "Apply connection history retention threshold", + Path: "netquery/history/cleanup", + Write: api.PermitUser, ActionFunc: func(ar *api.Request) (msg string, err error) { if err := nq.Store.CleanupHistory(ar.Context()); err != nil { return "", err diff --git a/service/network/api.go b/service/network/api.go index 0624f806f..78f7f7511 100644 --- a/service/network/api.go +++ b/service/network/api.go @@ -23,7 +23,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: "debug/network", Read: api.PermitUser, - BelongsTo: module, DataFunc: debugInfo, Name: "Get Network Debug Information", Description: "Returns network debugging information, similar to debug/core, but with connection data.", @@ -52,9 +51,8 @@ func registerAPIEndpoints() error { } if err := api.RegisterEndpoint(api.Endpoint{ - Path: "debug/network/state", - Read: api.PermitUser, - BelongsTo: module, + Path: "debug/network/state", + Read: api.PermitUser, StructFunc: func(ar *api.Request) (i interface{}, err error) { return state.GetInfo(), nil }, diff --git a/service/network/clean.go b/service/network/clean.go index 62fe04587..962fef332 100644 --- a/service/network/clean.go +++ b/service/network/clean.go @@ -32,14 +32,14 @@ const ( ) func connectionCleaner(ctx *mgr.WorkerCtx) error { - ticker := module.NewSleepyTicker(cleanerTickDuration, 0) + module.connectionCleanerTicker = mgr.NewSleepyTicker(cleanerTickDuration, 0) + defer module.connectionCleanerTicker.Stop() for { select { case <-ctx.Done(): - ticker.Stop() return nil - case <-ticker.Wait(): + case <-module.connectionCleanerTicker.Wait(): // clean connections and processes activePIDs := cleanConnections() process.CleanProcessStorage(activePIDs) diff --git a/service/network/connection.go b/service/network/connection.go index deff0ae94..e1d968b58 100644 --- a/service/network/connection.go +++ b/service/network/connection.go @@ -15,6 +15,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/intel" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" 
"github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/network/packet" @@ -833,7 +834,7 @@ func (conn *Connection) SetFirewallHandler(handler FirewallHandler) { // Start packet handler worker when new handler is set. if conn.firewallHandler == nil { - module.StartWorker("packet handler", conn.packetHandlerWorker) + module.mgr.Go("packet handler", conn.packetHandlerWorker) } // Set new handler. @@ -899,7 +900,7 @@ func (conn *Connection) HandlePacket(pkt packet.Packet) { var infoOnlyPacketsActive = abool.New() // packetHandlerWorker sequentially handles queued packets. -func (conn *Connection) packetHandlerWorker(ctx context.Context) error { +func (conn *Connection) packetHandlerWorker(ctx *mgr.WorkerCtx) error { // Copy packet queue, so we can remove the reference from the connection // when we stop the firewall handler. var pktQueue chan packet.Packet @@ -948,7 +949,7 @@ func (conn *Connection) packetHandlerWorker(ctx context.Context) error { if infoPkt != nil { // DEBUG: // log.Debugf("filter: packet #%d [pulled forward] info=%v PID=%d packet: %s", pktSeq, infoPkt.InfoOnly(), infoPkt.Info().PID, pkt) - packetHandlerHandleConn(ctx, conn, infoPkt) + packetHandlerHandleConn(ctx.Ctx(), conn, infoPkt) pktSeq++ } case <-time.After(1 * time.Millisecond): @@ -967,7 +968,7 @@ func (conn *Connection) packetHandlerWorker(ctx context.Context) error { // log.Debugf("filter: packet #%d info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) // } - packetHandlerHandleConn(ctx, conn, pkt) + packetHandlerHandleConn(ctx.Ctx(), conn, pkt) case <-ctx.Done(): return nil diff --git a/service/network/database.go b/service/network/database.go index 33fd8038d..7464f36c2 100644 --- a/service/network/database.go +++ b/service/network/database.go @@ -1,7 +1,6 @@ package network import ( - "context" "fmt" "strconv" "strings" @@ -11,6 +10,7 @@ import ( "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/database/storage" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/process" ) @@ -110,7 +110,7 @@ func (s *StorageInterface) Get(key string) (record.Record, error) { func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterator.Iterator, error) { it := iterator.New() - module.StartWorker("connection query", func(_ context.Context) error { + module.mgr.Go("connection query", func(_ *mgr.WorkerCtx) error { s.processQuery(q, it) return nil }) diff --git a/service/network/dns.go b/service/network/dns.go index c7e02b586..e8cdbb63d 100644 --- a/service/network/dns.go +++ b/service/network/dns.go @@ -175,14 +175,14 @@ func SaveOpenDNSRequest(q *resolver.Query, rrCache *resolver.RRCache, conn *Conn } func openDNSRequestWriter(ctx *mgr.WorkerCtx) error { - ticker := module.NewSleepyTicker(writeOpenDNSRequestsTickDuration, 0) - defer ticker.Stop() + module.dnsRequestTicker = mgr.NewSleepyTicker(writeOpenDNSRequestsTickDuration, 0) + defer module.dnsRequestTicker.Stop() for { select { case <-ctx.Done(): return nil - case <-ticker.Wait(): + case <-module.dnsRequestTicker.Wait(): writeOpenDNSRequestsToDB() } } diff --git a/service/network/module.go b/service/network/module.go index a14e3a7e9..bff26e561 100644 --- a/service/network/module.go +++ b/service/network/module.go @@ -24,6 +24,9 @@ type Network struct { mgr *mgr.Manager instance instance + dnsRequestTicker *mgr.SleepyTicker + connectionCleanerTicker *mgr.SleepyTicker + 
EventConnectionReattributed *mgr.EventMgr[string] } @@ -41,6 +44,15 @@ func (n *Network) Stop(mgr *mgr.Manager) error { return nil } +func (n *Network) SetSleep(enabled bool) { + if n.dnsRequestTicker != nil { + n.dnsRequestTicker.SetSleep(enabled) + } + if n.connectionCleanerTicker != nil { + n.connectionCleanerTicker.SetSleep(enabled) + } +} + var defaultFirewallHandler FirewallHandler // SetDefaultFirewallHandler sets the default firewall handler. diff --git a/service/process/api.go b/service/process/api.go index a17b3c70d..f4f669233 100644 --- a/service/process/api.go +++ b/service/process/api.go @@ -15,7 +15,6 @@ func registerAPIEndpoints() error { Description: "Get information about process tags.", Path: "process/tags", Read: api.PermitUser, - BelongsTo: module, StructFunc: handleProcessTagMetadata, }); err != nil { return err @@ -26,7 +25,6 @@ func registerAPIEndpoints() error { Description: "Get all recently active processes using the given profile", Path: "process/list/by-profile/{source:[a-z]+}/{id:[A-z0-9-]+}", Read: api.PermitUser, - BelongsTo: module, StructFunc: handleGetProcessesByProfile, }); err != nil { return err @@ -37,7 +35,6 @@ func registerAPIEndpoints() error { Description: "Load a process group leader by a child PID", Path: "process/group-leader/{pid:[0-9]+}", Read: api.PermitUser, - BelongsTo: module, StructFunc: handleGetProcessGroupLeader, }); err != nil { return err diff --git a/service/profile/active.go b/service/profile/active.go index 2ac053e7e..039db01fa 100644 --- a/service/profile/active.go +++ b/service/profile/active.go @@ -1,7 +1,6 @@ package profile import ( - "context" "sync" "time" diff --git a/service/profile/api.go b/service/profile/api.go index 60484e10e..e23a4c851 100644 --- a/service/profile/api.go +++ b/service/profile/api.go @@ -19,7 +19,6 @@ func registerAPIEndpoints() error { Description: "Merge multiple profiles into a new one.", Path: "profile/merge", Write: api.PermitUser, - BelongsTo: module, StructFunc: handleMergeProfiles, }); err != nil { return err @@ -30,7 +29,6 @@ func registerAPIEndpoints() error { Description: "Returns the requested profile icon.", Path: "profile/icon/{id:[a-f0-9]*\\.[a-z]{3,4}}", Read: api.PermitUser, - BelongsTo: module, DataFunc: handleGetProfileIcon, }); err != nil { return err @@ -41,7 +39,6 @@ func registerAPIEndpoints() error { Description: "Updates a profile icon.", Path: "profile/icon", Write: api.PermitUser, - BelongsTo: module, StructFunc: handleUpdateProfileIcon, }); err != nil { return err diff --git a/service/profile/config-update.go b/service/profile/config-update.go index eebf0b94b..b0436be6b 100644 --- a/service/profile/config-update.go +++ b/service/profile/config-update.go @@ -6,8 +6,8 @@ import ( "sync" "time" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/intel/filterlists" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/profile/endpoints" ) @@ -24,19 +24,16 @@ var ( ) func registerConfigUpdater() error { - return module.RegisterEventHook( - "config", - "config change", - "update global config profile", - func(ctx context.Context, _ interface{}) error { - return updateGlobalConfigProfile(ctx, nil) - }, - ) + module.instance.Config().EventConfigChange.AddCallback("update global config profile", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + return false, updateGlobalConfigProfile(wc.Ctx()) + }) + + return nil } const globalConfigProfileErrorID = "profile:global-profile-error" -func 
updateGlobalConfigProfile(ctx context.Context, task *modules.Task) error { +func updateGlobalConfigProfile(ctx context.Context) error { cfgLock.Lock() defer cfgLock.Unlock() @@ -132,25 +129,23 @@ func updateGlobalConfigProfile(ctx context.Context, task *modules.Task) error { // If there was any error, try again later until it succeeds. if lastErr == nil { - module.Resolve(globalConfigProfileErrorID) + module.States.Remove(globalConfigProfileErrorID) } else { // Create task after first failure. - if task == nil { - task = module.NewTask( - "retry updating global config profile", - updateGlobalConfigProfile, - ) - } // Schedule task. - task.Schedule(time.Now().Add(15 * time.Second)) + module.mgr.Delay("retry updating global config profile", 15*time.Second, + func(w *mgr.WorkerCtx) error { + return updateGlobalConfigProfile(w.Ctx()) + }) // Add module warning to inform user. - module.Warning( - globalConfigProfileErrorID, - "Internal Settings Failure", - fmt.Sprintf("Some global settings might not be applied correctly. You can try restarting the Portmaster to resolve this problem. Error: %s", err), - ) + module.States.Add(mgr.State{ + ID: globalConfigProfileErrorID, + Name: "Internal Settings Failure", + Message: fmt.Sprintf("Some global settings might not be applied correctly. You can try restarting the Portmaster to resolve this problem. Error: %s", err), + Type: mgr.StateTypeWarning, + }) } return lastErr diff --git a/service/profile/get.go b/service/profile/get.go index 36e3e0d9f..7fdc4b413 100644 --- a/service/profile/get.go +++ b/service/profile/get.go @@ -1,7 +1,6 @@ package profile import ( - "context" "errors" "fmt" "path" @@ -13,6 +12,7 @@ import ( "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" ) var getProfileLock sync.Mutex @@ -147,8 +147,8 @@ func GetLocalProfile(id string, md MatchingData, createProfileCallback func() *P // Trigger further metadata fetching from system if profile was created. if created && profile.UsePresentationPath && !special { - module.StartWorker("get profile metadata", func(ctx context.Context) error { - return profile.updateMetadataFromSystem(ctx, md) + module.mgr.Go("get profile metadata", func(wc *mgr.WorkerCtx) error { + return profile.updateMetadataFromSystem(wc.Ctx(), md) }) } diff --git a/service/profile/migrations.go b/service/profile/migrations.go index dfa6bf821..72829b0ce 100644 --- a/service/profile/migrations.go +++ b/service/profile/migrations.go @@ -11,6 +11,7 @@ import ( "github.com/safing/portmaster/base/database/migration" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/profile/binmeta" ) @@ -129,11 +130,12 @@ func migrateIcons(ctx context.Context, _, to *version.Version, db *database.Inte if lastErr != nil { // Normally, an icon migration would not be such a big error, but this is a test // run for the profile IDs and we absolutely need to know if anything went wrong. - module.Error( - "migration-failed", - "Profile Migration Failed", - fmt.Sprintf("Failed to migrate icons of %d profiles (out of %d pending). 
The last error was: %s\n\nPlease restart Portmaster to try the migration again.", failed, total, lastErr), - ) + module.States.Add(mgr.State{ + ID: "migration-failed", + Name: "Profile Migration Failed", + Message: fmt.Sprintf("Failed to migrate icons of %d profiles (out of %d pending). The last error was: %s\n\nPlease restart Portmaster to try the migration again.", failed, total, lastErr), + Type: mgr.StateTypeError, + }) return fmt.Errorf("failed to migrate %d profiles (out of %d pending) - last error: %w", failed, total, lastErr) } @@ -217,11 +219,12 @@ func migrateToDerivedIDs(ctx context.Context, _, to *version.Version, db *databa // Log migration failure and try again next time. if lastErr != nil { - module.Error( - "migration-failed", - "Profile Migration Failed", - fmt.Sprintf("Failed to migrate profile IDs of %d profiles (out of %d pending). The last error was: %s\n\nPlease restart Portmaster to try the migration again.", failed, total, lastErr), - ) + module.States.Add(mgr.State{ + ID: "migration-failed", + Name: "Profile Migration Failed", + Message: fmt.Sprintf("Failed to migrate profile IDs of %d profiles (out of %d pending). The last error was: %s\n\nPlease restart Portmaster to try the migration again.", failed, total, lastErr), + Type: mgr.StateTypeError, + }) return fmt.Errorf("failed to migrate %d profiles (out of %d pending) - last error: %w", failed, total, lastErr) } diff --git a/service/profile/module.go b/service/profile/module.go index 162b63cec..960719d18 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -6,11 +6,11 @@ import ( "os" "sync/atomic" + "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/migration" "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/profile/binmeta" @@ -37,6 +37,8 @@ type ProfileModule struct { EventConfigChange *mgr.EventMgr[string] EventDelete *mgr.EventMgr[string] EventMigrated *mgr.EventMgr[[]string] + + States *mgr.StateMgr } func (pm *ProfileModule) Start(m *mgr.Manager) error { @@ -46,6 +48,8 @@ func (pm *ProfileModule) Start(m *mgr.Manager) error { pm.EventDelete = mgr.NewEventMgr[string](DeletedEvent, m) pm.EventMigrated = mgr.NewEventMgr[[]string](MigratedEvent, m) + pm.States = mgr.NewStateMgr(m) + if err := prep(); err != nil { return err } @@ -115,7 +119,7 @@ func start() error { module.mgr.Go("clean active profiles", cleanActiveProfiles) - err = updateGlobalConfigProfile(module.mgr.Ctx(), nil) + err = updateGlobalConfigProfile(module.mgr.Ctx()) if err != nil { log.Warningf("profile: error during loading global profile from configuration: %s", err) } @@ -148,4 +152,6 @@ func NewModule(instance instance) (*ProfileModule, error) { return module, nil } -type instance interface{} +type instance interface { + Config() *config.Config +} diff --git a/service/resolver/api.go b/service/resolver/api.go index c16f031ec..e1814f18e 100644 --- a/service/resolver/api.go +++ b/service/resolver/api.go @@ -11,7 +11,6 @@ func registerAPI() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: "dns/clear", Write: api.PermitUser, - BelongsTo: module, ActionFunc: clearNameCacheHandler, Name: "Clear cached DNS records", Description: "Deletes all saved DNS records from the database.", @@ -22,7 +21,6 @@ func registerAPI() error { if 
err := api.RegisterEndpoint(api.Endpoint{ Path: "dns/resolvers", Read: api.PermitAnyone, - BelongsTo: module, StructFunc: exportDNSResolvers, Name: "List DNS Resolvers", Description: "List currently configured DNS resolvers and their status.", @@ -31,9 +29,8 @@ func registerAPI() error { } if err := api.RegisterEndpoint(api.Endpoint{ - Path: `dns/cache/{query:[a-z0-9\.-]{0,512}\.[A-Z]{1,32}}`, - Read: api.PermitUser, - BelongsTo: module, + Path: `dns/cache/{query:[a-z0-9\.-]{0,512}\.[A-Z]{1,32}}`, + Read: api.PermitUser, RecordFunc: func(r *api.Request) (record.Record, error) { return recordDatabase.Get(nameRecordsKeyPrefix + r.URLVars["query"]) }, diff --git a/service/resolver/failing.go b/service/resolver/failing.go index c8e011bd0..8c562642f 100644 --- a/service/resolver/failing.go +++ b/service/resolver/failing.go @@ -1,11 +1,9 @@ package resolver import ( - "context" "time" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" ) @@ -72,7 +70,7 @@ func (brc *BasicResolverConn) ResetFailure() { } } -func checkFailingResolvers(wc *mgr.WorkerCtx) error { //, task *modules.Task) error { +func checkFailingResolvers(wc *mgr.WorkerCtx) error { var resolvers []*Resolver // Make a copy of the resolver list. @@ -118,9 +116,7 @@ func checkFailingResolvers(wc *mgr.WorkerCtx) error { //, task *modules.Task) er } // Set next execution time. - if task != nil { - task.Schedule(time.Now().Add(time.Duration(nameserverRetryRate()) * time.Second)) - } + module.mgr.Delay("check failing resolvers", time.Duration(nameserverRetryRate())*time.Second, checkFailingResolvers) return nil } diff --git a/service/resolver/main.go b/service/resolver/main.go index 6d50dcc89..7083cbcc1 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -8,29 +8,30 @@ import ( "strings" "sync" "sync/atomic" - "time" "github.com/tevino/abool" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/base/utils/debug" _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" - "github.com/safing/portmaster/spn/captain" ) type ResolverModule struct { mgr *mgr.Manager instance instance + + States *mgr.StateMgr } func (rm *ResolverModule) Start(m *mgr.Manager) error { rm.mgr = m + rm.States = mgr.NewStateMgr(m) + if err := prep(); err != nil { return err } @@ -75,7 +76,7 @@ func start() error { ) // Force resolvers to reconnect when SPN has connected. - module.instance.Captain().EventSPNConnected.AddCallback( + module.instance.GetEventSPNConnected().AddCallback( "force resolver reconnect", func(ctx *mgr.WorkerCtx, _ struct{}) (bool, error) { ForceResolverReconnect(ctx.Ctx()) @@ -98,7 +99,7 @@ func start() error { }) // Check failing resolvers regularly and when the network changes. 
- module.mgr.Repeat("check failing resolvers", 1*time.Minute, checkFailingResolvers) + module.mgr.Do("check failing resolvers", checkFailingResolvers) module.instance.NetEnv().EventNetworkChange.AddCallback( "check failing resolvers", func(wc *mgr.WorkerCtx, _ struct{}) (bool, error) { @@ -106,7 +107,7 @@ func start() error { return false, nil }) - module.mgr.Repeat("suggest using stale cache", 2*time.Minute, suggestUsingStaleCacheTask) + module.mgr.Go("suggest using stale cache", suggestUsingStaleCacheTask) module.mgr.Go( "mdns handler", @@ -180,7 +181,8 @@ This notification will go away when Portmaster detects a working configured DNS notifications.Notify(n) failingResolverNotification = n - n.AttachToModule(module) + // TODO(vladimir): is this needed? + // n.AttachToModule(module) } func resetFailingResolversNotification() { @@ -198,7 +200,7 @@ func resetFailingResolversNotification() { } // Additionally, resolve the module error, if not done through the notification. - module.Resolve(failingResolverErrorID) + module.States.Remove(failingResolverErrorID) } // AddToDebugInfo adds the system status to the given debug.Info. @@ -259,6 +261,6 @@ func New(instance instance) (*ResolverModule, error) { type instance interface { NetEnv() *netenv.NetEnv - Captain() *captain.Captain Config() *config.Config + GetEventSPNConnected() *mgr.EventMgr[struct{}] } diff --git a/service/resolver/metrics.go b/service/resolver/metrics.go index c2589ee8f..f0062c9ef 100644 --- a/service/resolver/metrics.go +++ b/service/resolver/metrics.go @@ -1,12 +1,10 @@ package resolver import ( - "context" "sync/atomic" "time" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/mgr" ) @@ -50,11 +48,12 @@ func resetSlowQueriesSensorValue() { var suggestUsingStaleCacheNotification *notifications.Notification -func suggestUsingStaleCacheTask(_ *mgr.WorkerCtx) error { // t *modules.Task) error { +func suggestUsingStaleCacheTask(_ *mgr.WorkerCtx) error { + scheduleNextCall := true switch { case useStaleCache() || useStaleCacheConfigOption.IsSetByUser(): // If setting is already active, disable task repeating. - t.Repeat(0) + scheduleNextCall = false // Delete local reference, if used. 
if suggestUsingStaleCacheNotification != nil { @@ -103,6 +102,9 @@ func suggestUsingStaleCacheTask(_ *mgr.WorkerCtx) error { // t *modules.Task) er notifications.Notify(suggestUsingStaleCacheNotification) } + if scheduleNextCall { + module.mgr.Delay("suggest using stale cache", 2*time.Minute, suggestUsingStaleCacheTask) + } resetSlowQueriesSensorValue() return nil } diff --git a/service/resolver/resolve.go b/service/resolver/resolve.go index 1f85e59ef..403231f18 100644 --- a/service/resolver/resolve.go +++ b/service/resolver/resolve.go @@ -14,6 +14,7 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" ) @@ -307,8 +308,8 @@ func startAsyncQuery(ctx context.Context, q *Query, currentRRCache *RRCache) { } // resolve async - module.StartWorker("resolve async", func(asyncCtx context.Context) error { - tracingCtx, tracer := log.AddTracer(asyncCtx) + module.mgr.Go("resolve async", func(wc *mgr.WorkerCtx) error { + tracingCtx, tracer := log.AddTracer(wc.Ctx()) defer tracer.Submit() tracer.Tracef("resolver: resolving %s async", q.ID()) _, err := resolveAndCache(tracingCtx, q, nil) @@ -412,7 +413,7 @@ func resolveAndCache(ctx context.Context, q *Query, oldCache *RRCache) (rrCache // start resolving for _, resolver := range resolvers { - if module.IsStopping() { + if module.mgr.IsDone() { return nil, ErrShuttingDown } diff --git a/service/resolver/resolver-mdns.go b/service/resolver/resolver-mdns.go index 0eb674f98..ab08bdcfb 100644 --- a/service/resolver/resolver-mdns.go +++ b/service/resolver/resolver-mdns.go @@ -342,7 +342,7 @@ func listenForDNSPackets(ctx context.Context, conn *net.UDPConn, messages chan * for { n, err := conn.Read(buf) if err != nil { - if module.IsStopping() { + if module.mgr.IsDone() { return nil } log.Debugf("resolver: failed to read packet: %s", err) diff --git a/service/resolver/resolver-tcp.go b/service/resolver/resolver-tcp.go index d94a40e3b..0e14f3a71 100644 --- a/service/resolver/resolver-tcp.go +++ b/service/resolver/resolver-tcp.go @@ -13,6 +13,7 @@ import ( "github.com/tevino/abool" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" ) @@ -119,16 +120,18 @@ func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolve log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName()) case <-ctx.Done(): return nil, ctx.Err() - case <-module.Stopping(): - return nil, ErrShuttingDown + // TODO(vladimir): there is no need for this right? + // case <-module.Stopping(): + // return nil, ErrShuttingDown } } else { // If there is no resolver, check if we are shutting down before dialing! select { case <-ctx.Done(): return nil, ctx.Err() - case <-module.Stopping(): - return nil, ErrShuttingDown + // TODO(vladimir): there is no need for this right? + // case <-module.Stopping(): + // return nil, ErrShuttingDown default: } } @@ -175,7 +178,7 @@ func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolve } // Start worker. - module.StartWorker("dns client", resolverConn.handler) + module.mgr.Go("dns client", resolverConn.handler) // Set resolver conn for reuse. 
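The resolver changes above repeat one mechanical migration: workers formerly launched with module.StartWorker(name, func(ctx context.Context) error) are now launched with module.mgr.Go(name, func(wc *mgr.WorkerCtx) error), calling wc.Ctx() wherever a plain context.Context is still needed and the manager's IsDone() in place of module.IsStopping(). A small before/after sketch under that assumption; the function names are illustrative:

package example

import (
	"context"

	"github.com/safing/portmaster/service/mgr"
)

// doQuery stands in for work that still takes a plain context.
func doQuery(ctx context.Context) error {
	<-ctx.Done()
	return nil
}

// Old style (portbase modules):
//   module.StartWorker("dns client", func(ctx context.Context) error {
//       return doQuery(ctx)
//   })

// New style: the worker receives a *mgr.WorkerCtx and derives the plain
// context from it where one is required.
func startDNSClient(m *mgr.Manager) {
	m.Go("dns client", func(wc *mgr.WorkerCtx) error {
		return doQuery(wc.Ctx())
	})
}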
tr.resolverConn = resolverConn @@ -204,8 +207,9 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { case resolverConn.queries <- tq: case <-ctx.Done(): return nil, ctx.Err() - case <-module.Stopping(): - return nil, ErrShuttingDown + // TODO(vladimir): there is no need for this right? + // case <-module.Stopping(): + // return nil, ErrShuttingDown case <-time.After(defaultRequestTimeout): return nil, ErrTimeout } @@ -216,8 +220,9 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) { case reply = <-tq.Response: case <-ctx.Done(): return nil, ctx.Err() - case <-module.Stopping(): - return nil, ErrShuttingDown + // TODO(vladimir): there is no need for this right? + // case <-module.Stopping(): + // return nil, ErrShuttingDown case <-time.After(defaultRequestTimeout): return nil, ErrTimeout } @@ -282,9 +287,9 @@ func (trc *tcpResolverConn) shutdown() { } } -func (trc *tcpResolverConn) handler(workerCtx context.Context) error { +func (trc *tcpResolverConn) handler(workerCtx *mgr.WorkerCtx) error { // Set up context and cleanup. - trc.ctx, trc.cancelCtx = context.WithCancel(workerCtx) + trc.ctx, trc.cancelCtx = context.WithCancel(workerCtx.Ctx()) defer trc.shutdown() // Set up variables. @@ -292,7 +297,7 @@ func (trc *tcpResolverConn) handler(workerCtx context.Context) error { ttlTimer := time.After(defaultClientTTL) // Start connection reader. - module.StartWorker("dns client reader", trc.reader) + module.mgr.Go("dns client reader", trc.reader) // Handle requests. for { @@ -416,7 +421,7 @@ func (trc *tcpResolverConn) handleQueryResponse(msg *dns.Msg) { } } -func (trc *tcpResolverConn) reader(workerCtx context.Context) error { +func (trc *tcpResolverConn) reader(workerCtx *mgr.WorkerCtx) error { defer trc.cancelCtx() for { diff --git a/service/resolver/resolvers.go b/service/resolver/resolvers.go index 055aa6311..4c2335201 100644 --- a/service/resolver/resolvers.go +++ b/service/resolver/resolvers.go @@ -15,6 +15,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/utils" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" ) @@ -369,15 +370,15 @@ func loadResolvers() { defer resolversLock.Unlock() // Resolve module error about missing resolvers. - module.Resolve(missingResolversErrorID) + module.States.Remove(missingResolversErrorID) // Check if settings were changed and clear name cache when they did. newResolverConfig := configuredNameServers() if len(currentResolverConfig) > 0 && !utils.StringSliceEqual(currentResolverConfig, newResolverConfig) { - module.StartWorker("clear dns cache", func(ctx context.Context) error { + module.mgr.Go("clear dns cache", func(ctx *mgr.WorkerCtx) error { log.Info("resolver: clearing dns cache due to changed resolver config") - _, err := clearNameCache(ctx) + _, err := clearNameCache(ctx.Ctx()) return err }) } @@ -393,18 +394,20 @@ func loadResolvers() { newResolvers = getConfiguredResolvers(defaultNameServers) if len(newResolvers) > 0 { log.Warning("resolver: no (valid) dns server found in config or system, falling back to global defaults") - module.Warning( - missingResolversErrorID, - "Using Factory Default DNS Servers", - "The Portmaster could not find any (valid) DNS servers in the settings or system. In order to prevent being disconnected, the factory defaults are being used instead. 
If you just switched your network, this should be resolved shortly.", - ) + module.States.Add(mgr.State{ + ID: missingResolversErrorID, + Name: "Using Factory Default DNS Servers", + Message: "The Portmaster could not find any (valid) DNS servers in the settings or system. In order to prevent being disconnected, the factory defaults are being used instead. If you just switched your network, this should be resolved shortly.", + Type: mgr.StateTypeWarning, + }) } else { log.Critical("resolver: no (valid) dns server found in config, system or global defaults") - module.Error( - missingResolversErrorID, - "No DNS Servers Configured", - "The Portmaster could not find any (valid) DNS servers in the settings or system. You will experience severe connectivity problems until resolved. If you just switched your network, this should be resolved shortly.", - ) + module.States.Add(mgr.State{ + ID: missingResolversErrorID, + Name: "No DNS Servers Configured", + Message: "The Portmaster could not find any (valid) DNS servers in the settings or system. You will experience severe connectivity problems until resolved. If you just switched your network, this should be resolved shortly.", + Type: mgr.StateTypeError, + }) } } diff --git a/service/sync/profile.go b/service/sync/profile.go index cb75d37da..02ae07fa7 100644 --- a/service/sync/profile.go +++ b/service/sync/profile.go @@ -97,8 +97,7 @@ func registerProfileAPI() error { Field: "id", Description: "Specify scoped profile ID to export.", }}, - BelongsTo: module, - DataFunc: handleExportProfile, + DataFunc: handleExportProfile, }); err != nil { return err } @@ -128,7 +127,6 @@ func registerProfileAPI() error { Description: "Allow importing of unknown values.", }, }, - BelongsTo: module, StructFunc: handleImportProfile, }); err != nil { return err diff --git a/service/sync/setting_single.go b/service/sync/setting_single.go index 9fd9c9a6d..c738c1021 100644 --- a/service/sync/setting_single.go +++ b/service/sync/setting_single.go @@ -44,8 +44,7 @@ func registerSingleSettingAPI() error { Field: "key", Description: "Specify which settings key to export.", }}, - BelongsTo: module, - DataFunc: handleExportSingleSetting, + DataFunc: handleExportSingleSetting, }); err != nil { return err } @@ -69,7 +68,6 @@ func registerSingleSettingAPI() error { Field: "validate", Description: "Validate only.", }}, - BelongsTo: module, StructFunc: handleImportSingleSetting, }); err != nil { return err diff --git a/service/sync/settings.go b/service/sync/settings.go index 4d39b2874..3a7dd8e0e 100644 --- a/service/sync/settings.go +++ b/service/sync/settings.go @@ -52,8 +52,7 @@ func registerSettingsAPI() error { Field: "key", Description: "Optionally select a single setting to export. 
Repeat to export selection.", }}, - BelongsTo: module, - DataFunc: handleExportSettings, + DataFunc: handleExportSettings, }); err != nil { return err } @@ -81,7 +80,6 @@ func registerSettingsAPI() error { Field: "allowUnknown", Description: "Allow importing of unknown values.", }}, - BelongsTo: module, StructFunc: handleImportSettings, }); err != nil { return err diff --git a/service/updates/main.go b/service/updates/main.go index 46bfdc09d..5f4d1dee6 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -9,11 +9,11 @@ import ( "runtime" "time" - "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates/helper" ) @@ -49,7 +49,7 @@ var ( userAgentFromFlag string updateServerFromFlag string - updateTask *modules.Task + // updateTask *modules.Task updateASAP bool disableTaskSchedule bool @@ -204,9 +204,9 @@ func start() error { // MaxDelay(30 * time.Minute) } - if updateASAP { - updateTask.StartASAP() - } + // if updateASAP { + // updateTask.StartASAP() + // } // react to upgrades if err := initUpgrader(); err != nil { @@ -221,8 +221,9 @@ func start() error { // TriggerUpdate queues the update task to execute ASAP. func TriggerUpdate(forceIndexCheck, downloadAll bool) error { switch { - case !module.Online(): - updateASAP = true + // FIXME(vladimir): provide alternative for this + // case !module.Online(): + // updateASAP = true case !forceIndexCheck && !enableSoftwareUpdates() && !enableIntelUpdates(): return errors.New("automatic updating is disabled") @@ -236,11 +237,12 @@ func TriggerUpdate(forceIndexCheck, downloadAll bool) error { } // If index check if forced, start quicker. - if forceIndexCheck { - updateTask.StartASAP() - } else { - updateTask.Queue() - } + // FIXME(vladimir): provide alternative for this + // if forceIndexCheck { + // updateTask.StartASAP() + // } else { + // updateTask.Queue() + // } } log.Debugf("updates: triggering update to run as soon as possible") @@ -251,17 +253,18 @@ func TriggerUpdate(forceIndexCheck, downloadAll bool) error { // If called, updates are only checked when TriggerUpdate() // is called. func DisableUpdateSchedule() error { - switch module.Status() { - case modules.StatusStarting, modules.StatusOnline, modules.StatusStopping: - return errors.New("module already online") - } + // TODO: Updater state should be always on + // switch module.Status() { + // case modules.StatusStarting, modules.StatusOnline, modules.StatusStopping: + // return errors.New("module already online") + // } disableTaskSchedule = true return nil } -func checkForUpdates(ctx context.Context) (err error) { +func checkForUpdates(ctx *mgr.WorkerCtx) (err error) { // Set correct error if context was canceled. defer func() { select { @@ -294,12 +297,12 @@ func checkForUpdates(ctx context.Context) (err error) { notifyUpdateCheckFailed(forceIndexCheck, err) }() - if err = registry.UpdateIndexes(ctx); err != nil { + if err = registry.UpdateIndexes(ctx.Ctx()); err != nil { err = fmt.Errorf("failed to update indexes: %w", err) return //nolint:nakedret // TODO: Would "return err" work with the defer? 
} - err = registry.DownloadUpdates(ctx, downloadAll) + err = registry.DownloadUpdates(ctx.Ctx(), downloadAll) if err != nil { err = fmt.Errorf("failed to download updates: %w", err) return //nolint:nakedret // TODO: Would "return err" work with the defer? @@ -334,9 +337,9 @@ func stop() error { // RootPath returns the root path used for storing updates. func RootPath() string { - if !module.Online() { - return "" - } + // if !module.Online() { + // return "" + // } return registry.StorageDir().Path } diff --git a/service/updates/notify.go b/service/updates/notify.go index 662a1b82c..b20bd92fe 100644 --- a/service/updates/notify.go +++ b/service/updates/notify.go @@ -164,5 +164,5 @@ func notifyUpdateCheckFailed(force bool, err error) { ResultAction: "display", }, }, - ).AttachToModule(module) + ) // FIXME: add replacement for this .AttachToModule(module) } diff --git a/spn/access/client.go b/spn/access/client.go index b52e88433..70381c855 100644 --- a/spn/access/client.go +++ b/spn/access/client.go @@ -57,15 +57,16 @@ func makeClientRequest(opts *clientRequestOptions) (resp *http.Response, err err // Get context for request. var ctx context.Context var cancel context.CancelFunc - if module.Online() { - // Only use module context if online. - ctx, cancel = context.WithTimeout(module.Ctx, opts.requestTimeout) - defer cancel() - } else { - // Otherwise, use the background context. - ctx, cancel = context.WithTimeout(context.Background(), opts.requestTimeout) - defer cancel() - } + // TODO(vladimir): can the module not be online? + // if module.Online() { + // Only use module context if online. + ctx, cancel = context.WithTimeout(module.mgr.Ctx(), opts.requestTimeout) + defer cancel() + // } else { + // // Otherwise, use the background context. + // ctx, cancel = context.WithTimeout(context.Background(), opts.requestTimeout) + // defer cancel() + // } // Create new request. request, err := http.NewRequestWithContext(ctx, opts.method, opts.url, nil) diff --git a/spn/access/module.go b/spn/access/module.go index 2ff60cdba..7f5ada2df 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -112,8 +112,8 @@ func stop() error { func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error { // Retry sooner if the token issuer is failing. defer func() { - if tokenIssuerIsFailing.IsSet() && task != nil { - task.Schedule(time.Now().Add(tokenIssuerRetryDuration)) + if tokenIssuerIsFailing.IsSet() { + module.mgr.Delay("update account", tokenIssuerRetryDuration, UpdateAccount) } }() @@ -144,15 +144,17 @@ func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error { case time.Until(*u.Subscription.EndsAt) < 24*time.Hour && time.Since(*u.Subscription.EndsAt) < 24*time.Hour: - // Update account every hour 24h hours before and after the subscription ends. - task.Schedule(time.Now().Add(time.Hour)) + // Update account every hour for 24 hours before and after the subscription ends. + // TODO(vladimir): Goroutines will leak if this is called more than once. Figure out a way to test if this is already running. + module.mgr.Delay("update account", 1*time.Hour, UpdateAccount) case u.Subscription.NextBillingDate == nil: // No auto-subscription. case time.Until(*u.Subscription.NextBillingDate) < 24*time.Hour && time.Since(*u.Subscription.NextBillingDate) < 24*time.Hour: // Update account every hour 24h hours before and after the next billing date. - task.Schedule(time.Now().Add(time.Hour)) + // TODO(vladimir): Goroutines will leak if this is called more than once. 
Figure out a way to test if this is already running. + module.mgr.Delay("update account", 1*time.Hour, UpdateAccount) } return nil @@ -181,11 +183,12 @@ func tokenIssuerFailed() { if !tokenIssuerIsFailing.SetToIf(false, true) { return } - if !module.Online() { - return - } + // TODO(vladimir): Do we need this check? + // if !module.Online() { + // return + // } - accountUpdateTask.Schedule(time.Now().Add(tokenIssuerRetryDuration)) + module.mgr.Delay("update account", tokenIssuerRetryDuration, UpdateAccount) } // IsLoggedIn returns whether a User is currently logged in. diff --git a/spn/access/storage.go b/spn/access/storage.go index 6f6a924f3..617d3c66a 100644 --- a/spn/access/storage.go +++ b/spn/access/storage.go @@ -120,7 +120,7 @@ func clearTokens() { } // Purge database storage prefix. - ctx, cancel := context.WithTimeout(module.Ctx, 10*time.Second) + ctx, cancel := context.WithTimeout(module.mgr.Ctx(), 10*time.Second) defer cancel() n, err := db.Purge(ctx, query.New(fmt.Sprintf(tokenStorageKeyTemplate, ""))) if err != nil { diff --git a/spn/access/token/module_test.go b/spn/access/token/module_test.go index b3cc49b8a..2d00460aa 100644 --- a/spn/access/token/module_test.go +++ b/spn/access/token/module_test.go @@ -3,7 +3,6 @@ package token import ( "testing" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/core/pmtesting" ) diff --git a/spn/access/zones.go b/spn/access/zones.go index 444ebf2da..585756735 100644 --- a/spn/access/zones.go +++ b/spn/access/zones.go @@ -141,12 +141,13 @@ func initializeTestZone() error { func shouldRequestTokensHandler(_ token.Handler) { // accountUpdateTask is always set in client mode and when the module is online. // Check if it's set in case this gets executed in other circumstances. - if accountUpdateTask == nil { - log.Warningf("spn/access: trying to trigger account update, but the task is not available") - return - } + // if accountUpdateTask == nil { + // log.Warningf("spn/access: trying to trigger account update, but the task is not available") + // return + // } - accountUpdateTask.StartASAP() + // accountUpdateTask.StartASAP() + module.mgr.Go("update account", UpdateAccount) } // GetTokenAmount returns the amount of tokens for the given zones. diff --git a/spn/captain/api.go b/spn/captain/api.go index 0b37828b7..a3108facc 100644 --- a/spn/captain/api.go +++ b/spn/captain/api.go @@ -1,12 +1,7 @@ package captain import ( - "errors" - "fmt" - "github.com/safing/portmaster/base/api" - "github.com/safing/portmaster/base/database" - "github.com/safing/portmaster/base/database/query" ) const ( @@ -29,39 +24,41 @@ func registerAPIEndpoints() error { } func handleReInit(ar *api.Request) (msg string, err error) { - // Disable module and check - changed := module.Disable() - if !changed { - return "", errors.New("can only re-initialize when the SPN is enabled") - } + // FIXME: make a better way to disable and enable spn + // // Disable module and check + // changed := module.Disable() + // if !changed { + // return "", errors.New("can only re-initialize when the SPN is enabled") + // } - // Run module manager. - err = modules.ManageModules() - if err != nil { - return "", fmt.Errorf("failed to stop SPN: %w", err) - } + // // Run module manager. + // err = modules.ManageModules() + // if err != nil { + // return "", fmt.Errorf("failed to stop SPN: %w", err) + // } - // Delete SPN cache. 
- db := database.NewInterface(&database.Options{ - Local: true, - Internal: true, - }) - deletedRecords, err := db.Purge(ar.Context(), query.New("cache:spn/")) - if err != nil { - return "", fmt.Errorf("failed to delete SPN cache: %w", err) - } + // // Delete SPN cache. + // db := database.NewInterface(&database.Options{ + // Local: true, + // Internal: true, + // }) + // deletedRecords, err := db.Purge(ar.Context(), query.New("cache:spn/")) + // if err != nil { + // return "", fmt.Errorf("failed to delete SPN cache: %w", err) + // } - // Enable module. - module.Enable() + // // Enable module. + // module.Enable() - // Run module manager. - err = modules.ManageModules() - if err != nil { - return "", fmt.Errorf("failed to start SPN after cache reset: %w", err) - } + // // Run module manager. + // err = modules.ManageModules() + // if err != nil { + // return "", fmt.Errorf("failed to start SPN after cache reset: %w", err) + // } - return fmt.Sprintf( - "Completed SPN re-initialization and deleted %d cache records in the process.", - deletedRecords, - ), nil + // return fmt.Sprintf( + // "Completed SPN re-initialization and deleted %d cache records in the process.", + // deletedRecords, + // ), nil + return "", nil } diff --git a/spn/captain/bootstrap.go b/spn/captain/bootstrap.go index 6516b937a..4ccb3370e 100644 --- a/spn/captain/bootstrap.go +++ b/spn/captain/bootstrap.go @@ -74,7 +74,7 @@ func bootstrapWithUpdates() error { return errors.New("using the bootstrap-file argument disables bootstrapping via the update system") } - return updateSPNIntel(module.Ctx, nil) + return updateSPNIntel(module.mgr.Ctx(), nil) } // loadBootstrapFile loads a file with bootstrap hub entries and imports them. diff --git a/spn/captain/client.go b/spn/captain/client.go index d08727f63..4edc6264d 100644 --- a/spn/captain/client.go +++ b/spn/captain/client.go @@ -77,15 +77,16 @@ func clientManager(ctx *mgr.WorkerCtx) error { ready.UnSet() netenv.ConnectedToSPN.UnSet() resetSPNStatus(StatusDisabled, true) - module.Resolve("") + module.States.Clear() clientStopHomeHub(ctx.Ctx()) }() - module.Hint( - "spn:establishing-home-hub", - "Connecting to SPN...", - "Connecting to the SPN network is in progress.", - ) + module.States.Add(mgr.State{ + ID: "spn:establishing-home-hub", + Name: "Connecting to SPN...", + Message: "Connecting to the SPN network is in progress.", + Type: mgr.StateTypeHint, + }) // TODO: When we are starting and the SPN module is faster online than the // nameserver, then updating the account will fail as the DNS query is @@ -98,7 +99,8 @@ func clientManager(ctx *mgr.WorkerCtx) error { return nil } - healthCheckTicker := module.NewSleepyTicker(clientHealthCheckTickDuration, clientHealthCheckTickDurationSleepMode) + module.healthCheckTicker = mgr.NewSleepyTicker(clientHealthCheckTickDuration, clientHealthCheckTickDurationSleepMode) + defer module.healthCheckTicker.Stop() reconnect: for { @@ -180,7 +182,7 @@ reconnect: // Wait for signal to run maintenance again. select { - case <-healthCheckTicker.Wait(): + case <-module.healthCheckTicker.Wait(): case <-clientHealthCheckTrigger: case <-crew.ConnectErrors(): case <-clientNetworkChangedFlag.Signal(): @@ -226,7 +228,9 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { `Please restart Portmaster.`, // TODO: Add restart button. // TODO: Use special UI restart action in order to reload UI on restart. 
- ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) resetSPNStatus(StatusFailed, true) log.Errorf("spn/captain: client internal error: %s", err) return clientResultReconnect @@ -239,7 +243,9 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "SPN Login Required", `Please log in to access the SPN.`, spnLoginButton, - ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) resetSPNStatus(StatusFailed, true) log.Warningf("spn/captain: enabled but not logged in") return clientResultReconnect @@ -257,7 +263,9 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "spn:failed-to-update-user", "SPN Account Server Error", fmt.Sprintf(`The status of your SPN account could not be updated: %s`, err), - ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) resetSPNStatus(StatusFailed, true) log.Errorf("spn/captain: failed to update ineligible account: %s", err) return clientResultReconnect @@ -274,7 +282,9 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "SPN Not Included In Package", "Your current Portmaster Package does not include access to the SPN. Please upgrade your package on the Account Page.", spnOpenAccountPage, - ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) resetSPNStatus(StatusFailed, true) return clientResultReconnect } @@ -289,7 +299,9 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "Portmaster Package Issue", "Cannot enable SPN: "+message, spnOpenAccountPage, - ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) resetSPNStatus(StatusFailed, true) return clientResultReconnect } @@ -309,7 +321,9 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "spn:tokens-exhausted", "SPN Access Tokens Exhausted", `The Portmaster failed to get new access tokens to access the SPN. The Portmaster will automatically retry to get new access tokens.`, - ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) resetSPNStatus(StatusFailed, false) } return clientResultRetry @@ -357,7 +371,9 @@ func clientConnectToHomeHub(ctx context.Context) clientComponentResult { Key: CfgOptionHomeHubPolicyKey, }, }, - ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) case errors.Is(err, ErrReInitSPNSuggested): notifications.NotifyError( @@ -373,14 +389,18 @@ func clientConnectToHomeHub(ctx context.Context) clientComponentResult { ResultAction: "display", }, }, - ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) default: notifications.NotifyWarn( "spn:home-hub-failure", "SPN Failed to Connect", fmt.Sprintf("Failed to connect to a home hub: %s. The Portmaster will retry to connect automatically.", err), - ).AttachToModule(module) + ) + // TODO(vladimir): this is not needed right + // .AttachToModule(module) } return clientResultReconnect @@ -403,7 +423,7 @@ func clientSetActiveConnectionStatus(ctx context.Context) clientComponentResult } // Resolve any connection error. - module.Resolve("") + module.States.Clear() // Update SPN Status with connection information, if not already correctly set. 
spnStatus.Lock() diff --git a/spn/captain/hooks.go b/spn/captain/hooks.go index 6a60f7eab..c16d46ecf 100644 --- a/spn/captain/hooks.go +++ b/spn/captain/hooks.go @@ -1,8 +1,6 @@ package captain import ( - "time" - "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" @@ -34,7 +32,7 @@ func handleCraneUpdate(crane *docks.Crane) { func updateConnectionStatus() { // Delay updating status for a better chance to combine multiple changes. - statusUpdateTask.Schedule(time.Now().Add(maintainStatusUpdateDelay)) + module.mgr.Delay("maintain public status", maintainStatusUpdateDelay, maintainPublicStatus) // Check if we lost all connections and trigger a pending restart if we did. for _, crane := range docks.GetAllAssignedCranes() { diff --git a/spn/captain/intel.go b/spn/captain/intel.go index ff53bb4f2..f6bbe4fb8 100644 --- a/spn/captain/intel.go +++ b/spn/captain/intel.go @@ -6,8 +6,8 @@ import ( "os" "sync" - "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/updater" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/hub" @@ -23,23 +23,13 @@ var ( ) func registerIntelUpdateHook() error { - if err := module.RegisterEventHook( - updates.ModuleName, - updates.ResourceUpdateEvent, - "update SPN intel", - updateSPNIntel, - ); err != nil { - return err - } + module.instance.Updates().EventResourcesUpdated.AddCallback("update SPN intel", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + return false, updateSPNIntel(wc.Ctx(), nil) + }) - if err := module.RegisterEventHook( - "config", - config.ChangeEvent, - "update SPN intel", - updateSPNIntel, - ); err != nil { - return err - } + module.instance.Config().EventConfigChange.AddCallback("update SPN intel", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + return false, updateSPNIntel(wc.Ctx(), nil) + }) return nil } diff --git a/spn/captain/module.go b/spn/captain/module.go index 20876a22b..b6157dc0f 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -8,15 +8,14 @@ import ( "sync/atomic" "time" - "github.com/safing/portbase/modules/subsystems" "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" + "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/crew" "github.com/safing/portmaster/spn/navigator" @@ -34,6 +33,11 @@ type Captain struct { mgr *mgr.Manager instance instance + shutdownFunc func(exitCode int) + + healthCheckTicker *mgr.SleepyTicker + + States *mgr.StateMgr EventSPNConnected *mgr.EventMgr[struct{}] } @@ -51,25 +55,31 @@ func (c *Captain) Stop(m *mgr.Manager) error { return stop() } +func (c *Captain) SetSleep(enabled bool) { + if c.healthCheckTicker != nil { + c.healthCheckTicker.SetSleep(enabled) + } +} + func init() { - subsystems.Register( - "spn", - "SPN", - "Safing Privacy Network", - module, - "config:spn/", - &config.Option{ - Name: "SPN Module", - Key: CfgOptionEnableSPNKey, - Description: "Start the Safing Privacy Network module. 
If turned off, the SPN is fully disabled on this device.", - OptType: config.OptTypeBool, - DefaultValue: false, - Annotations: config.Annotations{ - config.DisplayOrderAnnotation: cfgOptionEnableSPNOrder, - config.CategoryAnnotation: "General", - }, - }, - ) + // subsystems.Register( + // "spn", + // "SPN", + // "Safing Privacy Network", + // module, + // "config:spn/", + // &config.Option{ + // Name: "SPN Module", + // Key: CfgOptionEnableSPNKey, + // Description: "Start the Safing Privacy Network module. If turned off, the SPN is fully disabled on this device.", + // OptType: config.OptTypeBool, + // DefaultValue: false, + // Annotations: config.Annotations{ + // config.DisplayOrderAnnotation: cfgOptionEnableSPNOrder, + // config.CategoryAnnotation: "General", + // }, + // }, + // ) } func prep() error { @@ -126,7 +136,7 @@ func start() error { // Load identity. if err := loadPublicIdentity(); err != nil { // We cannot recover from this, set controlled failure (do not retry). - modules.SetExitStatusCode(controlledFailureExitCode) + module.shutdownFunc(controlledFailureExitCode) return err } @@ -134,7 +144,7 @@ func start() error { // Check if any networks are configured. if !conf.HubHasIPv4() && !conf.HubHasIPv6() { // We cannot recover from this, set controlled failure (do not retry). - modules.SetExitStatusCode(controlledFailureExitCode) + module.shutdownFunc(controlledFailureExitCode) return errors.New("no IP addresses for Hub configured (or detected)") } @@ -233,13 +243,14 @@ var ( ) // New returns a new Captain module. -func New(instance instance) (*Captain, error) { +func New(instance instance, shutdownFunc func(exitCode int)) (*Captain, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } module = &Captain{ - instance: instance, + instance: instance, + shutdownFunc: shutdownFunc, } return module, nil } @@ -247,4 +258,6 @@ func New(instance instance) (*Captain, error) { type instance interface { NetEnv() *netenv.NetEnv Patrol() *patrol.Patrol + Config() *config.Config + Updates() *updates.Updates } diff --git a/spn/captain/navigation.go b/spn/captain/navigation.go index 5b6210b73..8b2a57a66 100644 --- a/spn/captain/navigation.go +++ b/spn/captain/navigation.go @@ -7,7 +7,6 @@ import ( "time" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" @@ -255,7 +254,7 @@ optimize: } else if createdConnections < result.MaxConnect { attemptedConnections++ - crane, tErr := EstablishPublicLane(ctx, connectTo.Hub) + crane, tErr := EstablishPublicLane(ctx.Ctx(), connectTo.Hub) if !tErr.IsOK() { log.Warningf("spn/captain: failed to establish lane to %s: %s", connectTo.Hub, tErr) } else { diff --git a/spn/captain/op_gossip.go b/spn/captain/op_gossip.go index 218e70e1c..4e6866044 100644 --- a/spn/captain/op_gossip.go +++ b/spn/captain/op_gossip.go @@ -128,7 +128,7 @@ func (op *GossipOp) Deliver(msg *terminal.Msg) *terminal.Error { } // Import and verify. 
- h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope) + h, forward, tErr := docks.ImportAndVerifyHubInfo(module.mgr.Ctx(), "", announcementData, statusData, conf.MainMapName, conf.MainMapScope) if tErr != nil { if tErr.Is(hub.ErrOldData) { log.Debugf("spn/captain: ignoring old %s from %s", gossipMsgType, op.craneID) diff --git a/spn/captain/op_gossip_query.go b/spn/captain/op_gossip_query.go index fbda083fb..bc7e6b7e7 100644 --- a/spn/captain/op_gossip_query.go +++ b/spn/captain/op_gossip_query.go @@ -8,6 +8,7 @@ import ( "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" @@ -63,12 +64,12 @@ func runGossipQueryOp(t terminal.Terminal, opID uint32, data *container.Containe op.ctx, op.cancelCtx = context.WithCancel(t.Ctx()) op.InitOperationBase(t, opID) - module.StartWorker("gossip query handler", op.handler) + module.mgr.Go("gossip query handler", op.handler) return op, nil } -func (op *GossipQueryOp) handler(_ context.Context) error { +func (op *GossipQueryOp) handler(_ *mgr.WorkerCtx) error { tErr := op.sendMsgs(hub.MsgTypeAnnouncement) if tErr != nil { op.Stop(op, tErr) @@ -166,7 +167,7 @@ func (op *GossipQueryOp) Deliver(msg *terminal.Msg) *terminal.Error { } // Import and verify. - h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope) + h, forward, tErr := docks.ImportAndVerifyHubInfo(module.mgr.Ctx(), "", announcementData, statusData, conf.MainMapName, conf.MainMapScope) if tErr != nil { log.Warningf("spn/captain: failed to import %s from gossip query: %s", gossipMsgType, tErr) } else { diff --git a/spn/captain/op_publish.go b/spn/captain/op_publish.go index 3a377df8b..c1fd29e70 100644 --- a/spn/captain/op_publish.go +++ b/spn/captain/op_publish.go @@ -85,7 +85,7 @@ func runPublishOp(t terminal.Terminal, opID uint32, data *container.Container) ( if err != nil { return nil, terminal.ErrMalformedData.With("failed to get status: %w", err) } - h, forward, tErr := docks.ImportAndVerifyHubInfo(module.Ctx, "", announcementData, statusData, conf.MainMapName, conf.MainMapScope) + h, forward, tErr := docks.ImportAndVerifyHubInfo(module.mgr.Ctx(), "", announcementData, statusData, conf.MainMapName, conf.MainMapScope) if tErr != nil { return nil, tErr.Wrap("failed to import and verify hub") } diff --git a/spn/captain/piers.go b/spn/captain/piers.go index c631e201b..6e639e421 100644 --- a/spn/captain/piers.go +++ b/spn/captain/piers.go @@ -7,6 +7,7 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/profile/endpoints" "github.com/safing/portmaster/spn/docks" @@ -45,7 +46,7 @@ func startPiers() error { } // Start worker to handle docking requests. - module.StartServiceWorker("docking request handler", 0, dockingRequestHandler) + module.mgr.Go("docking request handler", dockingRequestHandler) return nil } @@ -56,7 +57,7 @@ func stopPiers() { } } -func dockingRequestHandler(ctx context.Context) error { +func dockingRequestHandler(wc *mgr.WorkerCtx) error { // Sink all waiting ships when this worker ends. 
// But don't be destructive so the service worker could recover. defer func() { @@ -74,7 +75,7 @@ func dockingRequestHandler(ctx context.Context) error { for { select { - case <-ctx.Done(): + case <-wc.Done(): return nil case ship := <-dockingRequests: // Ignore nil ships. @@ -82,7 +83,7 @@ func dockingRequestHandler(ctx context.Context) error { continue } - if err := checkDockingPermission(ctx, ship); err != nil { + if err := checkDockingPermission(wc.Ctx(), ship); err != nil { log.Warningf("spn/captain: denied ship from %s to dock at pier %s: %s", ship.RemoteAddr(), ship.Transport().String(), err) } else { handleDockingRequest(ship) @@ -123,8 +124,8 @@ func handleDockingRequest(ship ships.Ship) { return } - module.StartWorker("start crane", func(ctx context.Context) error { - _ = crane.Start(ctx) + module.mgr.Go("start crane", func(wc *mgr.WorkerCtx) error { + _ = crane.Start(wc.Ctx()) // Crane handles errors internally. return nil }) diff --git a/spn/captain/public.go b/spn/captain/public.go index 441182d40..62e19d251 100644 --- a/spn/captain/public.go +++ b/spn/captain/public.go @@ -1,7 +1,6 @@ package captain import ( - "context" "errors" "fmt" "sort" @@ -10,7 +9,7 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/metrics" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" @@ -27,9 +26,6 @@ const ( var ( publicIdentity *cabin.Identity publicIdentityKey = "core:spn/public/identity" - - publicIdentityUpdateTask *modules.Task - statusUpdateTask *modules.Task ) func loadPublicIdentity() (err error) { @@ -42,7 +38,7 @@ func loadPublicIdentity() (err error) { log.Infof("spn/captain: loaded public hub identity %s", publicIdentity.Hub.ID) case errors.Is(err, database.ErrNotFound): // does not exist, create new - publicIdentity, err = cabin.CreateIdentity(module.Ctx, conf.MainMapName) + publicIdentity, err = cabin.CreateIdentity(module.mgr.Ctx(), conf.MainMapName) if err != nil { return fmt.Errorf("failed to create new identity: %w", err) } @@ -91,36 +87,22 @@ func loadPublicIdentity() (err error) { } func prepPublicIdentityMgmt() error { - publicIdentityUpdateTask = module.NewTask( - "maintain public identity", - maintainPublicIdentity, - ) + module.mgr.Repeat("maintain public status", maintainStatusInterval, maintainPublicStatus) - statusUpdateTask = module.NewTask( - "maintain public status", - maintainPublicStatus, - ).Repeat(maintainStatusInterval) - - return module.RegisterEventHook( - "config", - "config change", - "update public identity from config", - func(_ context.Context, _ interface{}) error { - // trigger update in 5 minutes - publicIdentityUpdateTask.Schedule(time.Now().Add(5 * time.Minute)) - return nil - }, - ) + module.instance.Config().EventConfigChange.AddCallback("update public identity from config", + func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + module.mgr.Delay("maintain public identity", 5*time.Minute, maintainPublicIdentity) + return false, nil + }) + return nil } // TriggerHubStatusMaintenance queues the Hub status update task to be executed. 
func TriggerHubStatusMaintenance() { - if statusUpdateTask != nil { - statusUpdateTask.Queue() - } + module.mgr.Go("maintain public status", maintainPublicStatus) } -func maintainPublicIdentity(ctx context.Context, task *modules.Task) error { +func maintainPublicIdentity(ctx *mgr.WorkerCtx) error { changed, err := publicIdentity.MaintainAnnouncement(nil, false) if err != nil { return fmt.Errorf("failed to maintain announcement: %w", err) @@ -146,7 +128,7 @@ func maintainPublicIdentity(ctx context.Context, task *modules.Task) error { return nil } -func maintainPublicStatus(ctx context.Context, task *modules.Task) error { +func maintainPublicStatus(ctx *mgr.WorkerCtx) error { // Get current lanes. cranes := docks.GetAllAssignedCranes() lanes := make([]*hub.Lane, 0, len(cranes)) diff --git a/spn/crew/sticky.go b/spn/crew/sticky.go index 4fcc39300..c6686c3a8 100644 --- a/spn/crew/sticky.go +++ b/spn/crew/sticky.go @@ -1,13 +1,12 @@ package crew import ( - "context" "fmt" "sync" "time" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/spn/navigator" @@ -149,7 +148,7 @@ func (t *Tunnel) avoidDestinationHub() { log.Warningf("spn/crew: avoiding %s for %s", t.dstPin.Hub, ipKey) } -func cleanStickyHubs(ctx context.Context, task *modules.Task) error { +func cleanStickyHubs(ctx *mgr.WorkerCtx) error { stickyLock.Lock() defer stickyLock.Unlock() diff --git a/spn/docks/controller.go b/spn/docks/controller.go index 4e8521757..9a1beb562 100644 --- a/spn/docks/controller.go +++ b/spn/docks/controller.go @@ -80,7 +80,7 @@ func initCraneController( t.GrantPermission(terminal.IsCraneController) // Start workers. - t.StartWorkers(module, "crane controller terminal") + t.StartWorkers(module.mgr, "crane controller terminal") return cct } diff --git a/spn/docks/crane.go b/spn/docks/crane.go index 5d1894462..65e27147e 100644 --- a/spn/docks/crane.go +++ b/spn/docks/crane.go @@ -427,7 +427,7 @@ func (crane *Crane) decrypt(shipment *container.Container) (decrypted *container return container.New(decryptedData), nil } -func (crane *Crane) unloader(workerCtx context.Context) error { +func (crane *Crane) unloader(workerCtx *mgr.WorkerCtx) error { // Unclean shutdown safeguard. 
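The captain scheduling changes above complete the removal of modules.Task from this package: a fixed interval becomes the manager's Repeat, a one-shot deferred run (the old task.Schedule) becomes Delay, and StartASAP/Queue-style triggers become a plain Go (or Do) call, as TriggerHubStatusMaintenance now does. A compact sketch of the three replacements under that assumption; names and intervals are illustrative:

package example

import (
	"time"

	"github.com/safing/portmaster/service/mgr"
)

// maintainStatus stands in for a real worker such as maintainPublicStatus.
func maintainStatus(_ *mgr.WorkerCtx) error { return nil }

func scheduleMaintenance(m *mgr.Manager) {
	// Was: module.NewTask("maintain status", ...).Repeat(maintainStatusInterval)
	m.Repeat("maintain status", 15*time.Minute, maintainStatus)

	// Was: task.Schedule(time.Now().Add(5 * time.Minute))
	m.Delay("maintain status", 5*time.Minute, maintainStatus)

	// Was: task.StartASAP() or task.Queue()
	m.Go("maintain status", maintainStatus)
}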
defer crane.Stop(terminal.ErrUnknownError.With("unloader died")) @@ -517,7 +517,7 @@ func (crane *Crane) unloadUntilFull(buf []byte) error { } } -func (crane *Crane) handler(workerCtx context.Context) error { +func (crane *Crane) handler(workerCtx *mgr.WorkerCtx) error { var partialShipment *container.Container var segmentLength uint32 @@ -646,7 +646,7 @@ handling: } } -func (crane *Crane) loader(workerCtx context.Context) (err error) { +func (crane *Crane) loader(workerCtx *mgr.WorkerCtx) (err error) { shipment := container.New() var partialShipment *container.Container var loadingTimer *time.Timer diff --git a/spn/docks/crane_establish.go b/spn/docks/crane_establish.go index 71637e456..d03896fe8 100644 --- a/spn/docks/crane_establish.go +++ b/spn/docks/crane_establish.go @@ -1,11 +1,11 @@ package docks import ( - "context" "time" "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" ) @@ -70,7 +70,7 @@ func (crane *Crane) establishTerminal(id uint32, initData *container.Container) case crane.terminalMsgs <- msg: default: // Send error async. - module.StartWorker("abandon terminal", func(ctx context.Context) error { + module.mgr.Go("abandon terminal", func(ctx *mgr.WorkerCtx) error { select { case crane.terminalMsgs <- msg: case <-ctx.Done(): diff --git a/spn/docks/crane_init.go b/spn/docks/crane_init.go index 472f9643d..807414119 100644 --- a/spn/docks/crane_init.go +++ b/spn/docks/crane_init.go @@ -71,7 +71,7 @@ func (crane *Crane) Start(callerCtx context.Context) error { } func (crane *Crane) startLocal(callerCtx context.Context) *terminal.Error { - module.StartWorker("crane unloader", crane.unloader) + module.mgr.Go("crane unloader", crane.unloader) if !crane.ship.IsSecure() { // Start encrypted channel. @@ -171,8 +171,8 @@ func (crane *Crane) startLocal(callerCtx context.Context) *terminal.Error { } // Start remaining workers. - module.StartWorker("crane loader", crane.loader) - module.StartWorker("crane handler", crane.handler) + module.mgr.Go("crane loader", crane.loader) + module.mgr.Go("crane handler", crane.handler) return nil } @@ -180,7 +180,7 @@ func (crane *Crane) startLocal(callerCtx context.Context) *terminal.Error { func (crane *Crane) startRemote(callerCtx context.Context) *terminal.Error { var initMsg *container.Container - module.StartWorker("crane unloader", crane.unloader) + module.mgr.Go("crane unloader", crane.unloader) handling: for { @@ -270,8 +270,8 @@ handling: } // Start remaining workers. - module.StartWorker("crane loader", crane.loader) - module.StartWorker("crane handler", crane.handler) + module.mgr.Go("crane loader", crane.loader) + module.mgr.Go("crane handler", crane.handler) return nil } diff --git a/spn/docks/crane_terminal.go b/spn/docks/crane_terminal.go index 7ac506092..5a6d7a53b 100644 --- a/spn/docks/crane_terminal.go +++ b/spn/docks/crane_terminal.go @@ -74,7 +74,7 @@ func initCraneTerminal( t.SetTerminalExtension(ct) // Start workers. 
- t.StartWorkers(module, "crane terminal") + t.StartWorkers(module.mgr, "crane terminal") return ct } diff --git a/spn/docks/hub_import.go b/spn/docks/hub_import.go index c8f46d307..ff2981337 100644 --- a/spn/docks/hub_import.go +++ b/spn/docks/hub_import.go @@ -170,7 +170,7 @@ func verifyHubIP(ctx context.Context, h *hub.Hub, ip net.IP) error { if err != nil { return fmt.Errorf("failed to create crane: %w", err) } - module.StartWorker("crane unloader", crane.unloader) + module.mgr.Go("crane unloader", crane.unloader) defer crane.Stop(nil) // Verify Hub. diff --git a/spn/docks/op_capacity.go b/spn/docks/op_capacity.go index a4ca5b5bb..9c924d97d 100644 --- a/spn/docks/op_capacity.go +++ b/spn/docks/op_capacity.go @@ -2,7 +2,6 @@ package docks import ( "bytes" - "context" "sync/atomic" "time" @@ -11,6 +10,7 @@ import ( "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" ) @@ -116,7 +116,7 @@ func NewCapacityTestOp(t terminal.Terminal, opts *CapacityTestOptions) (*Capacit } // Start handler. - module.StartWorker("op capacity handler", op.handler) + module.mgr.Go("op capacity handler", op.handler) return op, nil } @@ -157,13 +157,13 @@ func startCapacityTestOp(t terminal.Terminal, opID uint32, data *container.Conta // Start handler and sender. op.senderStarted = true - module.StartWorker("op capacity handler", op.handler) - module.StartWorker("op capacity sender", op.sender) + module.mgr.Go("op capacity handler", op.handler) + module.mgr.Go("op capacity sender", op.sender) return op, nil } -func (op *CapacityTestOp) handler(ctx context.Context) error { +func (op *CapacityTestOp) handler(ctx *mgr.WorkerCtx) error { defer capacityTestRunning.UnSet() returnErr := terminal.ErrStopping @@ -204,7 +204,7 @@ func (op *CapacityTestOp) handler(ctx context.Context) error { maxTestTimeReached = time.After(op.opts.MaxTime) if !op.senderStarted { op.senderStarted = true - module.StartWorker("op capacity sender", op.sender) + module.mgr.Go("op capacity sender", op.sender) } } @@ -241,7 +241,7 @@ func (op *CapacityTestOp) handler(ctx context.Context) error { } } -func (op *CapacityTestOp) sender(ctx context.Context) error { +func (op *CapacityTestOp) sender(ctx *mgr.WorkerCtx) error { for { // Send next chunk. msg := op.NewMsg(capacityTestSendData) diff --git a/spn/docks/op_expand.go b/spn/docks/op_expand.go index 567d6fc47..f4c69a747 100644 --- a/spn/docks/op_expand.go +++ b/spn/docks/op_expand.go @@ -9,6 +9,7 @@ import ( "github.com/tevino/abool" "github.com/safing/portmaster/base/container" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" ) @@ -201,13 +202,13 @@ func expand(t terminal.Terminal, opID uint32, data *container.Container) (termin } // Start workers. 
- module.StartWorker("expand op forward relay", op.forwardHandler) - module.StartWorker("expand op backward relay", op.backwardHandler) + module.mgr.Go("expand op forward relay", op.forwardHandler) + module.mgr.Go("expand op backward relay", op.backwardHandler) if op.flowControl != nil { - op.flowControl.StartWorkers(module, "expand op") + op.flowControl.StartWorkers(module.mgr, "expand op") } if op.relayTerminal.flowControl != nil { - op.relayTerminal.flowControl.StartWorkers(module, "expand op terminal") + op.relayTerminal.flowControl.StartWorkers(module.mgr, "expand op terminal") } return op, nil @@ -259,7 +260,7 @@ func (op *ExpandOp) submitBackwardUpstream(msg *terminal.Msg, timeout time.Durat } } -func (op *ExpandOp) forwardHandler(_ context.Context) error { +func (op *ExpandOp) forwardHandler(_ *mgr.WorkerCtx) error { // Metrics setup and submitting. atomic.AddInt64(activeExpandOps, 1) started := time.Now() @@ -290,7 +291,7 @@ func (op *ExpandOp) forwardHandler(_ context.Context) error { } } -func (op *ExpandOp) backwardHandler(_ context.Context) error { +func (op *ExpandOp) backwardHandler(_ *mgr.WorkerCtx) error { for { select { case msg := <-op.relayTerminal.recvProxy(): @@ -336,7 +337,7 @@ func (op *ExpandOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Error // Abandon shuts down the terminal unregistering it from upstream and calling HandleAbandon(). func (t *ExpansionRelayTerminal) Abandon(err *terminal.Error) { if t.abandoning.SetToIf(false, true) { - module.StartWorker("terminal abandon procedure", func(_ context.Context) error { + module.mgr.Go("terminal abandon procedure", func(_ *mgr.WorkerCtx) error { t.handleAbandonProcedure(err) return nil }) diff --git a/spn/docks/op_latency.go b/spn/docks/op_latency.go index 12e9b75eb..7e2c19339 100644 --- a/spn/docks/op_latency.go +++ b/spn/docks/op_latency.go @@ -2,7 +2,6 @@ package docks import ( "bytes" - "context" "fmt" "time" @@ -10,6 +9,7 @@ import ( "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" ) @@ -82,12 +82,12 @@ func NewLatencyTestOp(t terminal.Terminal) (*LatencyTestClientOp, *terminal.Erro } // Start handler. - module.StartWorker("op latency handler", op.handler) + module.mgr.Go("op latency handler", op.handler) return op, nil } -func (op *LatencyTestClientOp) handler(ctx context.Context) error { +func (op *LatencyTestClientOp) handler(ctx *mgr.WorkerCtx) error { returnErr := terminal.ErrStopping defer func() { // Linters don't get that returnErr is used when directly used as defer. diff --git a/spn/docks/op_sync_state.go b/spn/docks/op_sync_state.go index e6f964611..72bb04f3e 100644 --- a/spn/docks/op_sync_state.go +++ b/spn/docks/op_sync_state.go @@ -6,6 +6,7 @@ import ( "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" ) @@ -39,8 +40,8 @@ func init() { // startSyncStateOp starts a worker that runs the sync state operation. 
func (crane *Crane) startSyncStateOp() { - module.StartWorker("sync crane state", func(ctx context.Context) error { - tErr := crane.Controller.SyncState(ctx) + module.mgr.Go("sync crane state", func(wc *mgr.WorkerCtx) error { + tErr := crane.Controller.SyncState(wc.Ctx()) if tErr != nil { return tErr } diff --git a/spn/docks/terminal_expansion.go b/spn/docks/terminal_expansion.go index 442e9bf5d..a04f93cb3 100644 --- a/spn/docks/terminal_expansion.go +++ b/spn/docks/terminal_expansion.go @@ -53,7 +53,7 @@ func ExpandTo(from terminal.Terminal, routeTo string, encryptFor *hub.Hub) (*Exp // Create base terminal for expansion. base, initData, tErr := terminal.NewLocalBaseTerminal( - module.Ctx, + module.mgr.Ctx(), 0, // Ignore; The ID of the operation is used for communication. from.FmtID(), encryptFor, @@ -81,7 +81,7 @@ func ExpandTo(from terminal.Terminal, routeTo string, encryptFor *hub.Hub) (*Exp } // Start Workers. - base.StartWorkers(module, "expansion terminal") + base.StartWorkers(module.mgr, "expansion terminal") return expansion, nil } diff --git a/spn/hub/hub_test.go b/spn/hub/hub_test.go index 8bd14ce90..d715653e8 100644 --- a/spn/hub/hub_test.go +++ b/spn/hub/hub_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/safing/portmaster/base/modules" _ "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/core/pmtesting" ) diff --git a/spn/navigator/api.go b/spn/navigator/api.go index 2da6bfa91..fe695570c 100644 --- a/spn/navigator/api.go +++ b/spn/navigator/api.go @@ -52,7 +52,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: `spn/map/{map:[A-Za-z0-9]{1,255}}/pins`, Read: api.PermitUser, - BelongsTo: module, StructFunc: handleMapPinsRequest, Name: "Get SPN map pins", Description: "Returns a list of pins on the map.", @@ -63,7 +62,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: `spn/map/{map:[A-Za-z0-9]{1,255}}/intel/update`, Write: api.PermitSelf, - BelongsTo: module, ActionFunc: handleIntelUpdateRequest, Name: "Update map intelligence.", Description: "Updates the intel data of the map.", @@ -74,7 +72,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: `spn/map/{map:[A-Za-z0-9]{1,255}}/optimization`, Read: api.PermitUser, - BelongsTo: module, StructFunc: handleMapOptimizationRequest, Name: "Get SPN map optimization", Description: "Returns the calculated optimization for the map.", @@ -85,7 +82,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: `spn/map/{map:[A-Za-z0-9]{1,255}}/optimization/table`, Read: api.PermitUser, - BelongsTo: module, DataFunc: handleMapOptimizationTableRequest, Name: "Get SPN map optimization as a table", Description: "Returns the calculated optimization for the map as a table.", @@ -96,7 +92,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: `spn/map/{map:[A-Za-z0-9]{1,255}}/measurements`, Read: api.PermitUser, - BelongsTo: module, StructFunc: handleMapMeasurementsRequest, Name: "Get SPN map measurements", Description: "Returns the measurements of the map.", @@ -108,7 +103,6 @@ func registerAPIEndpoints() error { Path: `spn/map/{map:[A-Za-z0-9]{1,255}}/measurements/table`, MimeType: api.MimeTypeText, Read: api.PermitUser, - BelongsTo: module, DataFunc: handleMapMeasurementsTableRequest, Name: "Get SPN map measurements as a table", Description: "Returns the measurements of the map as a table.", 
@@ -119,7 +113,6 @@ func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: `spn/map/{map:[A-Za-z0-9]{1,255}}/graph{format:\.[a-z]{2,4}}`, Read: api.PermitUser, - BelongsTo: module, HandlerFunc: handleMapGraphRequest, Name: "Get SPN map graph", Description: "Returns a graph of the given SPN map.", diff --git a/spn/navigator/api_route.go b/spn/navigator/api_route.go index bae4a27fc..965cb5856 100644 --- a/spn/navigator/api_route.go +++ b/spn/navigator/api_route.go @@ -25,7 +25,6 @@ func registerRouteAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ Path: `spn/map/{map:[A-Za-z0-9]{1,255}}/route/to/{destination:[a-z0-9_\.:-]{1,255}}`, Read: api.PermitUser, - BelongsTo: module, ActionFunc: handleRouteCalculationRequest, Name: "Calculate Route through SPN", Description: "Returns a textual representation of the routing process.", diff --git a/spn/navigator/database.go b/spn/navigator/database.go index 7a4a88b3e..62288951b 100644 --- a/spn/navigator/database.go +++ b/spn/navigator/database.go @@ -1,7 +1,6 @@ package navigator import ( - "context" "fmt" "strings" @@ -10,6 +9,7 @@ import ( "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/database/storage" + "github.com/safing/portmaster/service/mgr" ) var mapDBController *database.Controller @@ -82,7 +82,7 @@ func (s *StorageInterface) Query(q *query.Query, local, internal bool) (*iterato // Start query worker. it := iterator.New() - module.StartWorker("map query", func(_ context.Context) error { + module.mgr.Go("map query", func(_ *mgr.WorkerCtx) error { s.processQuery(m, q, it) return nil }) @@ -131,10 +131,10 @@ func withdrawMapDatabase() { // PushPinChanges pushes all changed pins to subscribers. func (m *Map) PushPinChanges() { - module.StartWorker("push pin changes", m.pushPinChangesWorker) + module.mgr.Go("push pin changes", m.pushPinChangesWorker) } -func (m *Map) pushPinChangesWorker(ctx context.Context) error { +func (m *Map) pushPinChangesWorker(ctx *mgr.WorkerCtx) error { m.RLock() defer m.RUnlock() @@ -155,7 +155,7 @@ func (pin *Pin) pushChange() { } // Start worker to push changes. - module.StartWorker("push pin change", func(ctx context.Context) error { + module.mgr.Go("push pin change", func(ctx *mgr.WorkerCtx) error { if pin.pushChanges.SetToIf(true, false) { mapDBController.PushUpdate(pin.Export()) } diff --git a/spn/navigator/measurements.go b/spn/navigator/measurements.go index f2784d1e0..2fd20abdd 100644 --- a/spn/navigator/measurements.go +++ b/spn/navigator/measurements.go @@ -6,7 +6,6 @@ import ( "time" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/terminal" ) @@ -25,7 +24,7 @@ const ( // 1000c -> 100h -> capped to 50h. 
) -func (m *Map) measureHubs(ctx context.Context, _ *modules.Task) error { +func (m *Map) measureHubs(ctx context.Context) error { if home, _ := m.GetHome(); home == nil { log.Debug("spn/navigator: skipping measuring, no home hub set") return nil diff --git a/spn/navigator/update.go b/spn/navigator/update.go index b04a58691..dfe0c54c6 100644 --- a/spn/navigator/update.go +++ b/spn/navigator/update.go @@ -15,7 +15,6 @@ import ( "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/mgr" @@ -421,7 +420,7 @@ func (m *Map) ResetFailingStates(ctx context.Context) { m.PushPinChanges() } -func (m *Map) updateFailingStates(ctx context.Context, task *modules.Task) error { +func (m *Map) updateFailingStates(ctx *mgr.WorkerCtx) error { m.Lock() defer m.Unlock() @@ -434,7 +433,7 @@ func (m *Map) updateFailingStates(ctx context.Context, task *modules.Task) error return nil } -func (m *Map) updateStates(ctx context.Context, task *modules.Task) error { +func (m *Map) updateStates(ctx *mgr.WorkerCtx) error { var toDelete []string m.Lock() @@ -459,7 +458,7 @@ pinLoop: // Delete hubs async, as deleting triggers a couple hooks that lock the map. if len(toDelete) > 0 { - module.StartWorker("delete hubs", func(_ context.Context) error { + module.mgr.Go("delete hubs", func(_ *mgr.WorkerCtx) error { for _, idToDelete := range toDelete { err := hub.RemoveHubAndMsgs(m.Name, idToDelete) if err != nil { diff --git a/spn/patrol/http.go b/spn/patrol/http.go index e1396fde6..c3bb14bdd 100644 --- a/spn/patrol/http.go +++ b/spn/patrol/http.go @@ -10,7 +10,7 @@ import ( "github.com/tevino/abool" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) @@ -22,9 +22,9 @@ func HTTPSConnectivityConfirmed() bool { return httpsConnectivityConfirmed.IsSet() } -func connectivityCheckTask(ctx context.Context, task *modules.Task) error { +func connectivityCheckTask(wc *mgr.WorkerCtx) error { // Start tracing logs. - ctx, tracer := log.AddTracer(ctx) + ctx, tracer := log.AddTracer(wc.Ctx()) defer tracer.Submit() // Run checks and report status. @@ -32,14 +32,14 @@ func connectivityCheckTask(ctx context.Context, task *modules.Task) error { if success { tracer.Info("spn/patrol: all connectivity checks succeeded") if httpsConnectivityConfirmed.SetToIf(false, true) { - module.TriggerEvent(ChangeSignalEventName, nil) + module.EventChangeSignal.Submit(struct{}{}) } return nil } tracer.Errorf("spn/patrol: connectivity check failed") if httpsConnectivityConfirmed.SetToIf(true, false) { - module.TriggerEvent(ChangeSignalEventName, nil) + module.EventChangeSignal.Submit(struct{}{}) } return nil } diff --git a/spn/patrol/module.go b/spn/patrol/module.go index 0962f17fa..414c06165 100644 --- a/spn/patrol/module.go +++ b/spn/patrol/module.go @@ -13,17 +13,18 @@ import ( const ChangeSignalEventName = "change signal" type Patrol struct { + mgr *mgr.Manager instance instance EventChangeSignal *mgr.EventMgr[struct{}] } func (p *Patrol) Start(m *mgr.Manager) error { + p.mgr = m p.EventChangeSignal = mgr.NewEventMgr[struct{}](ChangeSignalEventName, m) if conf.PublicHub() { - module.NewTask("connectivity test", connectivityCheckTask). 
- Repeat(5 * time.Minute) + m.Repeat("connectivity test", 5*time.Minute, connectivityCheckTask) } return nil } diff --git a/spn/ships/http_shared.go b/spn/ships/http_shared.go index 3ebc49485..eddbea232 100644 --- a/spn/ships/http_shared.go +++ b/spn/ships/http_shared.go @@ -10,6 +10,7 @@ import ( "time" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) @@ -97,7 +98,7 @@ func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) error { IdleTimeout: 1 * time.Minute, MaxHeaderBytes: 4096, // ErrorLog: &log.Logger{}, // FIXME - BaseContext: func(net.Listener) context.Context { return module.Ctx }, + BaseContext: func(net.Listener) context.Context { return module.mgr.Ctx() }, } shared.server = server @@ -123,9 +124,9 @@ func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) error { // Start servers in service workers. for _, listener := range listeners { serviceListener := listener - module.StartServiceWorker( - fmt.Sprintf("shared http server listener on %s", listener.Addr()), 0, - func(ctx context.Context) error { + module.mgr.Go( + fmt.Sprintf("shared http server listener on %s", listener.Addr()), + func(_ *mgr.WorkerCtx) error { err := shared.server.Serve(serviceListener) if !errors.Is(http.ErrServerClosed, err) { return err diff --git a/spn/ships/module.go b/spn/ships/module.go index 27b1fbc3b..1bdd8c95b 100644 --- a/spn/ships/module.go +++ b/spn/ships/module.go @@ -9,10 +9,12 @@ import ( ) type Ships struct { + mgr *mgr.Manager instance instance } func (s *Ships) Start(m *mgr.Manager) error { + s.mgr = m if conf.PublicHub() { initPageInput() } diff --git a/spn/ships/tcp.go b/spn/ships/tcp.go index ffc6c6979..e6ee00d2f 100644 --- a/spn/ships/tcp.go +++ b/spn/ships/tcp.go @@ -7,6 +7,7 @@ import ( "time" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/hub" ) @@ -81,7 +82,7 @@ func establishTCPPier(transport *hub.Transport, dockingRequests chan Ship) (Pier } // Create new pier. - pierCtx, cancelCtx := context.WithCancel(module.Ctx) + pierCtx, cancelCtx := context.WithCancel(module.mgr.Ctx()) pier := &TCPPier{ PierBase: PierBase{ transport: transport, @@ -95,9 +96,8 @@ func establishTCPPier(transport *hub.Transport, dockingRequests chan Ship) (Pier // Start workers. 
for _, listener := range pier.listeners { - serviceListener := listener - module.StartServiceWorker("accept TCP docking requests", 0, func(ctx context.Context) error { - return pier.dockingWorker(ctx, serviceListener) + module.mgr.Go("accept TCP docking requests", func(wc *mgr.WorkerCtx) error { + return pier.dockingWorker(wc.Ctx(), listener) }) } diff --git a/spn/sluice/packet_listener.go b/spn/sluice/packet_listener.go index 3eb64cbbb..b3c1f026d 100644 --- a/spn/sluice/packet_listener.go +++ b/spn/sluice/packet_listener.go @@ -1,13 +1,13 @@ package sluice import ( - "context" "io" "net" "sync" "sync/atomic" "time" + "github.com/safing/portmaster/service/mgr" "github.com/tevino/abool" ) @@ -37,8 +37,8 @@ func ListenPacket(network, address string) (net.Listener, error) { newConns: make(chan *PacketConn), conns: make(map[string]*PacketConn), } - module.StartServiceWorker("packet listener reader", 0, ln.reader) - module.StartServiceWorker("packet listener cleaner", time.Minute, ln.cleaner) + module.mgr.Go("packet listener reader", ln.reader) + module.mgr.Go("packet listener cleaner", ln.cleaner) return ln, nil } @@ -99,7 +99,7 @@ func (ln *PacketListener) setConn(conn *PacketConn) { ln.conns[conn.addr.String()] = conn } -func (ln *PacketListener) reader(_ context.Context) error { +func (ln *PacketListener) reader(_ *mgr.WorkerCtx) error { for { // Read data from connection. buf := make([]byte, 512) @@ -145,7 +145,7 @@ func (ln *PacketListener) reader(_ context.Context) error { } } -func (ln *PacketListener) cleaner(ctx context.Context) error { +func (ln *PacketListener) cleaner(ctx *mgr.WorkerCtx) error { for { select { case <-time.After(1 * time.Minute): diff --git a/spn/sluice/sluice.go b/spn/sluice/sluice.go index 6a3249f90..8c993c3ac 100644 --- a/spn/sluice/sluice.go +++ b/spn/sluice/sluice.go @@ -1,7 +1,6 @@ package sluice import ( - "context" "fmt" "net" "strconv" @@ -9,6 +8,7 @@ import ( "time" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" ) @@ -46,9 +46,8 @@ func StartSluice(network, address string) { } // Start service worker. - module.StartServiceWorker( + module.mgr.Go( s.network+" sluice listener", - 10*time.Second, s.listenHandler, ) } @@ -189,7 +188,7 @@ func (s *Sluice) handleConnection(conn net.Conn) { success = true } -func (s *Sluice) listenHandler(_ context.Context) error { +func (s *Sluice) listenHandler(_ *mgr.WorkerCtx) error { defer s.abandon() err := s.init() if err != nil { @@ -201,7 +200,7 @@ func (s *Sluice) listenHandler(_ context.Context) error { for { conn, err := s.listener.Accept() if err != nil { - if module.IsStopping() { + if module.mgr.IsDone() { return nil } return fmt.Errorf("failed to accept connection: %w", err) diff --git a/spn/sluice/udp_listener.go b/spn/sluice/udp_listener.go index 4065d5205..31f83e077 100644 --- a/spn/sluice/udp_listener.go +++ b/spn/sluice/udp_listener.go @@ -1,7 +1,6 @@ package sluice import ( - "context" "io" "net" "runtime" @@ -9,6 +8,7 @@ import ( "sync/atomic" "time" + "github.com/safing/portmaster/service/mgr" "github.com/tevino/abool" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -64,8 +64,8 @@ func ListenUDP(network, address string) (net.Listener, error) { } // Start workers. 
- module.StartServiceWorker("udp listener reader", 0, ln.reader) - module.StartServiceWorker("udp listener cleaner", time.Minute, ln.cleaner) + module.mgr.Go("udp listener reader", ln.reader) + module.mgr.Go("udp listener cleaner", ln.cleaner) return ln, nil } @@ -126,7 +126,7 @@ func (ln *UDPListener) setConn(conn *UDPConn) { ln.conns[conn.addr.String()] = conn } -func (ln *UDPListener) reader(_ context.Context) error { +func (ln *UDPListener) reader(_ *mgr.WorkerCtx) error { for { // TODO: Find good buf size. // With a buf size of 512 we have seen this error on Windows: @@ -180,7 +180,7 @@ func (ln *UDPListener) reader(_ context.Context) error { } } -func (ln *UDPListener) cleaner(ctx context.Context) error { +func (ln *UDPListener) cleaner(ctx *mgr.WorkerCtx) error { for { select { case <-time.After(1 * time.Minute): diff --git a/spn/terminal/control_flow.go b/spn/terminal/control_flow.go index 24685206a..af328fe88 100644 --- a/spn/terminal/control_flow.go +++ b/spn/terminal/control_flow.go @@ -8,7 +8,6 @@ import ( "time" "github.com/safing/portmaster/base/formats/varint" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/service/mgr" ) @@ -19,7 +18,7 @@ type FlowControl interface { Send(msg *Msg, timeout time.Duration) *Error ReadyToSend() <-chan struct{} Flush(timeout time.Duration) - StartWorkers(m *modules.Module, terminalName string) + StartWorkers(m *mgr.Manager, terminalName string) RecvQueueLen() int SendQueueLen() int } @@ -122,8 +121,8 @@ func NewDuplexFlowQueue( } // StartWorkers starts the necessary workers to operate the flow queue. -func (dfq *DuplexFlowQueue) StartWorkers(m *modules.Module, terminalName string) { - m.StartWorker(terminalName+" flow queue", dfq.FlowHandler) +func (dfq *DuplexFlowQueue) StartWorkers(m *mgr.Manager, terminalName string) { + m.Go(terminalName+" flow queue", dfq.FlowHandler) } // shouldReportRecvSpace returns whether the receive space should be reported. diff --git a/spn/terminal/operation_counter.go b/spn/terminal/operation_counter.go index 1a732ce86..1609a9455 100644 --- a/spn/terminal/operation_counter.go +++ b/spn/terminal/operation_counter.go @@ -1,7 +1,6 @@ package terminal import ( - "context" "fmt" "sync" "time" @@ -10,6 +9,7 @@ import ( "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" ) // CounterOpType is the type ID for the Counter Operation. @@ -68,7 +68,7 @@ func NewCounterOp(t Terminal, opts CounterOpts) (*CounterOp, *Error) { // Start worker if needed. if op.getRemoteCounterTarget() > 0 && !op.opts.suppressWorker { - module.StartWorker("counter sender", op.CounterWorker) + module.mgr.Go("counter sender", op.CounterWorker) } return op, nil } @@ -91,7 +91,7 @@ func startCounterOp(t Terminal, opID uint32, data *container.Container) (Operati // Start worker if needed. if op.getRemoteCounterTarget() > 0 { - module.StartWorker("counter sender", op.CounterWorker) + module.mgr.Go("counter sender", op.CounterWorker) } return op, nil @@ -219,7 +219,7 @@ func (op *CounterOp) Wait() { } // CounterWorker is a worker that sends counters. -func (op *CounterOp) CounterWorker(ctx context.Context) error { +func (op *CounterOp) CounterWorker(ctx *mgr.WorkerCtx) error { for { // Send counter msg. 
err := op.SendCounter() diff --git a/spn/terminal/terminal.go b/spn/terminal/terminal.go index dc0587564..e4b044f9e 100644 --- a/spn/terminal/terminal.go +++ b/spn/terminal/terminal.go @@ -231,10 +231,10 @@ func (t *TerminalBase) Deliver(msg *Msg) *Error { } // StartWorkers starts the necessary workers to operate the Terminal. -func (t *TerminalBase) StartWorkers(m *modules.Module, terminalName string) { +func (t *TerminalBase) StartWorkers(m *mgr.Manager, terminalName string) { // Start terminal workers. - m.StartWorker(terminalName+" handler", t.Handler) - m.StartWorker(terminalName+" sender", t.Sender) + m.Go(terminalName+" handler", t.Handler) + m.Go(terminalName+" sender", t.Sender) // Start any flow control workers. if t.flowControl != nil { @@ -250,7 +250,7 @@ const ( // Handler receives and handles messages and must be started as a worker in the // module where the Terminal is used. -func (t *TerminalBase) Handler(_ context.Context) error { +func (t *TerminalBase) Handler(_ *mgr.WorkerCtx) error { defer t.Abandon(ErrInternalError.With("handler died")) var msg *Msg @@ -322,7 +322,7 @@ func (t *TerminalBase) submitToUpstream(msg *Msg, timeout time.Duration) { // Sender handles sending messages and must be started as a worker in the // module where the Terminal is used. -func (t *TerminalBase) Sender(_ context.Context) error { +func (t *TerminalBase) Sender(_ *mgr.WorkerCtx) error { // Don't send messages, if the encryption is net yet set up. // The server encryption session is only initialized with the first // operative message, not on Terminal creation. @@ -782,7 +782,7 @@ func (t *TerminalBase) sendOpMsgs(msg *Msg) *Error { // Should not be overridden by implementations. func (t *TerminalBase) Abandon(err *Error) { if t.Abandoning.SetToIf(false, true) { - module.StartWorker("terminal abandon procedure", func(_ context.Context) error { + module.mgr.Go("terminal abandon procedure", func(_ *mgr.WorkerCtx) error { t.handleAbandonProcedure(err) return nil }) diff --git a/spn/terminal/testing.go b/spn/terminal/testing.go index 1c59de690..eca8e9fd8 100644 --- a/spn/terminal/testing.go +++ b/spn/terminal/testing.go @@ -35,7 +35,7 @@ func NewLocalTestTerminal( if err != nil { return nil, nil, err } - t.StartWorkers(module, "test terminal") + // t.StartWorkers(module, "test terminal") return &TestTerminal{t}, initData, nil } @@ -54,7 +54,7 @@ func NewRemoteTestTerminal( if err != nil { return nil, nil, err } - t.StartWorkers(module, "test terminal") + // t.StartWorkers(module, "test terminal") return &TestTerminal{t}, initMsg, nil } @@ -138,36 +138,36 @@ func (t *TestTerminal) HandleAbandon(err *Error) (errorToSend *Error) { // NewSimpleTestTerminalPair provides a simple conntected terminal pair for tests. 
func NewSimpleTestTerminalPair(delay time.Duration, delayQueueSize int, opts *TerminalOpts) (a, b *TestTerminal, err error) { - if opts == nil { - opts = &TerminalOpts{ - Padding: defaultTestPadding, - FlowControl: FlowControlDFQ, - FlowControlSize: defaultTestQueueSize, - } - } - - var initData *container.Container - var tErr *Error - a, initData, tErr = NewLocalTestTerminal( - module.Ctx, 127, "a", nil, opts, UpstreamSendFunc(createDelayingTestForwardingFunc( - "a", "b", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error { - return b.Deliver(msg) - }, - )), - ) - if tErr != nil { - return nil, nil, tErr.Wrap("failed to create local test terminal") - } - b, _, tErr = NewRemoteTestTerminal( - module.Ctx, 127, "b", nil, initData, UpstreamSendFunc(createDelayingTestForwardingFunc( - "b", "a", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error { - return a.Deliver(msg) - }, - )), - ) - if tErr != nil { - return nil, nil, tErr.Wrap("failed to create remote test terminal") - } + // if opts == nil { + // opts = &TerminalOpts{ + // Padding: defaultTestPadding, + // FlowControl: FlowControlDFQ, + // FlowControlSize: defaultTestQueueSize, + // } + // } + + // var initData *container.Container + // var tErr *Error + // a, initData, tErr = NewLocalTestTerminal( + // module.Ctx, 127, "a", nil, opts, UpstreamSendFunc(createDelayingTestForwardingFunc( + // "a", "b", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error { + // return b.Deliver(msg) + // }, + // )), + // ) + // if tErr != nil { + // return nil, nil, tErr.Wrap("failed to create local test terminal") + // } + // b, _, tErr = NewRemoteTestTerminal( + // module.Ctx, 127, "b", nil, initData, UpstreamSendFunc(createDelayingTestForwardingFunc( + // "b", "a", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error { + // return a.Deliver(msg) + // }, + // )), + // ) + // if tErr != nil { + // return nil, nil, tErr.Wrap("failed to create remote test terminal") + // } return a, b, nil } diff --git a/spn/unit/scheduler.go b/spn/unit/scheduler.go index 1db0c5014..9027241fb 100644 --- a/spn/unit/scheduler.go +++ b/spn/unit/scheduler.go @@ -1,7 +1,6 @@ package unit import ( - "context" "errors" "math" "sync" From 89e439904672c6b228c2a2ed1e3b374040497584 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Mon, 24 Jun 2024 11:46:50 +0300 Subject: [PATCH 11/56] [WIP] add new task system to module manager --- base/api/main.go | 2 +- base/api/module.go | 2 +- base/database/dbmodule/maintenance.go | 6 +- service/broadcasts/module.go | 2 +- service/compat/module.go | 17 +-- service/core/base/logs.go | 5 +- service/intel/customlists/module.go | 21 ++-- service/intel/geoip/database.go | 2 +- service/mgr/task.go | 164 ++++++++++++++++++++++++++ service/mgr/worker.go | 132 +++------------------ service/netquery/module_api.go | 11 +- service/profile/config-update.go | 6 +- service/resolver/failing.go | 6 +- service/resolver/main.go | 9 +- service/resolver/metrics.go | 2 +- service/updates/main.go | 31 ++--- service/updates/module.go | 2 + service/updates/upgrader.go | 4 +- spn/access/module.go | 26 ++-- spn/captain/hooks.go | 2 +- spn/captain/module.go | 9 +- spn/captain/public.go | 6 +- spn/crew/module.go | 2 +- spn/navigator/measurements.go | 6 +- spn/navigator/module.go | 15 +-- spn/patrol/module.go | 2 +- 26 files changed, 259 insertions(+), 233 deletions(-) create mode 100644 service/mgr/task.go diff --git a/base/api/main.go b/base/api/main.go index 26b79350d..d50c643cb 100644 --- 
a/base/api/main.go +++ b/base/api/main.go @@ -64,7 +64,7 @@ func start() error { // start api auth token cleaner if authFnSet.IsSet() { - module.mgr.Repeat("clean api sessions", 5*time.Minute, cleanSessions) + _ = module.mgr.Repeat("clean api sessions", 5*time.Minute, cleanSessions, nil) } return registerEndpointBridgeDB() diff --git a/base/api/module.go b/base/api/module.go index af5ea3a1f..01ea043ac 100644 --- a/base/api/module.go +++ b/base/api/module.go @@ -32,7 +32,7 @@ func (api *API) Start(m *mgr.Manager) error { // Stop stops the module. func (api *API) Stop(_ *mgr.Manager) error { - return start() + return stop() } var ( diff --git a/base/database/dbmodule/maintenance.go b/base/database/dbmodule/maintenance.go index 64704460a..a899f4077 100644 --- a/base/database/dbmodule/maintenance.go +++ b/base/database/dbmodule/maintenance.go @@ -9,9 +9,9 @@ import ( ) func startMaintenanceTasks() { - module.mgr.Repeat("basic maintenance", 10*time.Minute, maintainBasic) - module.mgr.Repeat("thorough maintenance", 1*time.Hour, maintainThorough) - module.mgr.Repeat("record maintenance", 1*time.Hour, maintainRecords) + _ = module.mgr.Repeat("basic maintenance", 10*time.Minute, maintainBasic, nil) + _ = module.mgr.Repeat("thorough maintenance", 1*time.Hour, maintainThorough, nil) + _ = module.mgr.Repeat("record maintenance", 1*time.Hour, maintainRecords, nil) } func maintainBasic(ctx *mgr.WorkerCtx) error { diff --git a/service/broadcasts/module.go b/service/broadcasts/module.go index 1f6d3882b..2ef89d948 100644 --- a/service/broadcasts/module.go +++ b/service/broadcasts/module.go @@ -55,7 +55,7 @@ func start() error { // Start broadcast notifier task. startOnce.Do(func() { - module.mgr.Repeat("broadcast notifier", 10*time.Minute, broadcastNotify) + module.mgr.Repeat("broadcast notifier", 10*time.Minute, broadcastNotify, nil) }) return nil diff --git a/service/compat/module.go b/service/compat/module.go index 23a230e6a..1efaa690c 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -16,6 +16,8 @@ import ( type Compat struct { mgr *mgr.Manager instance instance + + selfcheckTask *mgr.Task } // Start starts the module. @@ -54,8 +56,6 @@ var ( const selfcheckFailThreshold = 10 func init() { - // module = modules.Register("compat", prep, start, stop, "base", "network", "interception", "netenv", "notifications") - // Workaround resolver integration. // See resolver/compat.go for details. resolver.CompatDNSCheckInternalDomainScope = DNSCheckInternalDomainScope @@ -71,15 +71,11 @@ func start() error { startNotify() selfcheckNetworkChangedFlag.Refresh() - module.mgr.Repeat("compatibility self-check", 5*time.Minute, selfcheckTaskFunc) - // selfcheckTask = module.NewTask("compatibility self-check", selfcheckTaskFunc). - // Repeat(5 * time.Minute). - // MaxDelay(selfcheckTaskRetryAfter). 
- // Schedule(time.Now().Add(selfcheckTaskRetryAfter)) + module.selfcheckTask = module.mgr.Repeat("compatibility self-check", 5*time.Minute, selfcheckTaskFunc, nil).Delay(selfcheckTaskRetryAfter) - module.mgr.Repeat("clean notify thresholds", 1*time.Hour, cleanNotifyThreshold) + _ = module.mgr.Repeat("clean notify thresholds", 1*time.Hour, cleanNotifyThreshold, nil) module.instance.NetEnv().EventNetworkChange.AddCallback("trigger compat self-check", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { - module.mgr.Delay("trigger compat self-check", selfcheckTaskRetryAfter, selfcheckTaskFunc) + module.selfcheckTask.Delay(selfcheckTaskRetryAfter) return false, nil }) return nil @@ -126,8 +122,7 @@ func selfcheckTaskFunc(wc *mgr.WorkerCtx) error { } // Retry quicker when failed. - module.mgr.Delay("trigger compat self-check", selfcheckTaskRetryAfter, selfcheckTaskFunc) - // task.Schedule(time.Now().Add(selfcheckTaskRetryAfter)) + module.selfcheckTask.Delay(selfcheckTaskRetryAfter) return nil } diff --git a/service/core/base/logs.go b/service/core/base/logs.go index b78c75e17..c04a048fe 100644 --- a/service/core/base/logs.go +++ b/service/core/base/logs.go @@ -19,10 +19,7 @@ const ( ) func registerLogCleaner() { - module.mgr.Delay("log cleaner delay", 15*time.Minute, func(w *mgr.WorkerCtx) error { - module.mgr.Repeat("log cleaner", 24*time.Hour, logCleaner) - return nil - }) + _ = module.mgr.Delay("log cleaner", 15*time.Minute, logCleaner, nil).Repeat(24 * time.Hour) } func logCleaner(_ *mgr.WorkerCtx) error { diff --git a/service/intel/customlists/module.go b/service/intel/customlists/module.go index c7c59edf9..a21250a44 100644 --- a/service/intel/customlists/module.go +++ b/service/intel/customlists/module.go @@ -21,6 +21,8 @@ type CustomList struct { mgr *mgr.Manager instance instance + updateFilterListTask *mgr.Task + States *mgr.StateMgr } @@ -28,6 +30,8 @@ func (cl *CustomList) Start(m *mgr.Manager) error { cl.mgr = m cl.States = mgr.NewStateMgr(m) + cl.updateFilterListTask = m.NewTask("update custom filter list", checkAndUpdateFilterList, nil) + if err := prep(); err != nil { return err } @@ -69,7 +73,7 @@ func prep() error { Path: "customlists/update", Write: api.PermitUser, ActionFunc: func(ar *api.Request) (msg string, err error) { - errCheck := checkAndUpdateFilterList() + errCheck := checkAndUpdateFilterList(nil) if errCheck != nil { return "", errCheck } @@ -88,8 +92,8 @@ func start() error { // Register to hook to update after config change. module.instance.Config().EventConfigChange.AddCallback( "update custom filter list", - func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { - if err := checkAndUpdateFilterList(); !errors.Is(err, ErrNotConfigured) { + func(wc *mgr.WorkerCtx, _ struct{}) (bool, error) { + if err := checkAndUpdateFilterList(wc); !errors.Is(err, ErrNotConfigured) { return false, err } return false, nil @@ -97,15 +101,12 @@ func start() error { ) // Create parser task and enqueue for execution. "checkAndUpdateFilterList" will schedule the next execution. 
- module.mgr.Repeat("intel/customlists:file-update-check", 20*time.Second, func(_ *mgr.WorkerCtx) error { - _ = checkAndUpdateFilterList() - return nil - }) + module.updateFilterListTask.Delay(20 * time.Second).Repeat(1 * time.Minute) return nil } -func checkAndUpdateFilterList() error { +func checkAndUpdateFilterList(_ *mgr.WorkerCtx) error { filterListLock.Lock() defer filterListLock.Unlock() @@ -115,10 +116,6 @@ func checkAndUpdateFilterList() error { return ErrNotConfigured } - // Schedule next update check - // TODO(vladimir): The task is set to repeate evry few seconds does. Is there another way to make it better? - // parserTask.Schedule(time.Now().Add(1 * time.Minute)) - // Try to get file info modifiedTime := time.Now() if fileInfo, err := os.Stat(filePath); err == nil { diff --git a/service/intel/geoip/database.go b/service/intel/geoip/database.go index 6aee3d944..1bffed787 100644 --- a/service/intel/geoip/database.go +++ b/service/intel/geoip/database.go @@ -148,7 +148,7 @@ func (upd *updateWorker) triggerUpdate() { func (upd *updateWorker) start() { upd.once.Do(func() { - module.mgr.Delay("geoip-updater", time.Second*10, upd.run) + module.mgr.Delay("geoip-updater", time.Second*10, upd.run, nil) }) } diff --git a/service/mgr/task.go b/service/mgr/task.go new file mode 100644 index 000000000..385767c26 --- /dev/null +++ b/service/mgr/task.go @@ -0,0 +1,164 @@ +package mgr + +import ( + "sync" + "time" +) + +type taskMode int + +const ( + taskModeOnDemand taskMode = iota + taskModeDelay + taskModeRepeat +) + +// Task holds info about a task that can be scheduled for execution later. +type Task struct { + name string + runChannel chan struct{} + + tickerMutex sync.Mutex + mode taskMode + runTicker *time.Ticker + repeatDuration time.Duration + + mgr *Manager +} + +// NewTask creates a new task that can be scheduled for execution later. +// By default error/panic will be logged. For custom behavior supply errorFn, the argument is optional. +func (m *Manager) NewTask(name string, taskFn func(*WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *Task { + t := &Task{ + name: name, + runChannel: make(chan struct{}), + mgr: m, + mode: taskModeOnDemand, + repeatDuration: 0, + } + + go t.taskLoop(taskFn, errorFn) + + return t +} + +func (t *Task) initTicker(duration time.Duration) { + t.runTicker = time.NewTicker(duration) + go func() { + for { + select { + case <-t.runTicker.C: + t.tickerMutex.Lock() + + // Handle execution + switch t.mode { + case taskModeDelay: + // Run once and disable delay + t.Go() + if t.repeatDuration == 0 { + t.mode = taskModeOnDemand + // Reset the timer with a large value so it does not eat unnecessary resources, + t.runTicker.Reset(24 * time.Hour) + } else { + // Repeat was called, switch to repeat mode + t.mode = taskModeRepeat + t.runTicker.Reset(t.repeatDuration) + } + case taskModeRepeat: + t.Go() + case taskModeOnDemand: + // On Demand is triggered only when the Go function as called + } + + t.tickerMutex.Unlock() + case <-t.mgr.Done(): + return + } + } + }() +} + +func (t *Task) stopTicker() { + t.tickerMutex.Lock() + defer t.tickerMutex.Unlock() + if t.runTicker != nil { + t.runTicker.Stop() + t.runTicker = nil + } +} + +func (t *Task) taskLoop(fn func(*WorkerCtx) error, errorFn func(*WorkerCtx, error, string)) { + t.mgr.workerStart() + defer t.mgr.workerDone() + defer t.stopTicker() + + w := &WorkerCtx{ + logger: t.mgr.logger.With("worker", t.name), + } + for { + // Wait for a signal to run. 
+ select { + case <-t.runChannel: + case <-w.Done(): + return + } + + panicInfo, err := t.mgr.runWorker(w, fn) + if err != nil { + // Handle error/panic + if panicInfo != "" { + t.mgr.Error( + "worker failed", + "err", err, + "file", panicInfo, + ) + } else { + t.mgr.Error( + "worker failed", + "err", err, + ) + } + if errorFn != nil { + errorFn(w, err, panicInfo) + } + } + } +} + +// Go will send request for the task to run and return immediately. +func (t *Task) Go() { + t.runChannel <- struct{}{} +} + +// Delay will schedule the task to run after the given delay. +// If there is active repeating, it will be pause until the delay has elapsed. +func (t *Task) Delay(delay time.Duration) *Task { + t.tickerMutex.Lock() + defer t.tickerMutex.Unlock() + t.mode = taskModeDelay + if t.runTicker == nil { + t.initTicker(delay) + } else { + t.runTicker.Reset(delay) + } + return t +} + +// Repeat will schedule the task to run every time duration elapses. +// If Delay was called before, the repeating will start after the first delay has elapsed. +func (t *Task) Repeat(duration time.Duration) *Task { + t.tickerMutex.Lock() + defer t.tickerMutex.Unlock() + t.repeatDuration = duration + + if t.mode != taskModeDelay { + t.mode = taskModeRepeat + + if t.runTicker == nil { + t.initTicker(duration) + } else { + t.runTicker.Reset(duration) + } + } + return t +} diff --git a/service/mgr/worker.go b/service/mgr/worker.go index 95933660f..395e45968 100644 --- a/service/mgr/worker.go +++ b/service/mgr/worker.go @@ -139,7 +139,7 @@ func (m *Manager) manageWorker(name string, fn func(w *WorkerCtx) error) { return case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): - // A canceled context or dexceeded eadline also means that the worker is finished. + // A canceled context or exceeded deadline also means that the worker is finished. return default: @@ -195,26 +195,6 @@ func (m *Manager) manageWorker(name string, fn func(w *WorkerCtx) error) { } } -// Delay starts the given function delayed in a goroutine (as a "worker"). -// The worker context has -// - A separate context which is canceled when the functions returns. -// - Access to named structure logging. -// - Given function is re-run after failure (with backoff). -// - Panic catching. -// - Flow control helpers. -func (m *Manager) Delay(name string, delay time.Duration, fn func(w *WorkerCtx) error) { - go m.delayWorker(name, delay, fn) -} - -func (m *Manager) delayWorker(name string, delay time.Duration, fn func(w *WorkerCtx) error) { - select { - case <-time.After(delay): - case <-m.ctx.Done(): - return - } - m.manageWorker(name, fn) -} - // Do directly executes the given function (as a "worker"). // The worker context has // - A separate context which is canceled when the functions returns. @@ -308,103 +288,21 @@ func (m *Manager) runWorker(w *WorkerCtx, fn func(w *WorkerCtx) error) (panicInf // The worker context has // - A separate context which is canceled when the functions returns. // - Access to named structure logging. -// - Given function is re-run after failure (with backoff). -// - Panic catching. +// - By default error/panic will be logged. For custom behavior supply errorFn, the argument is optional. // - Flow control helpers. 
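The new service/mgr/task.go above introduces a Task type that replaces the removed modules.Task scheduling. Here is a usage sketch based only on the methods defined in that file; the manager, task name, work function, and durations are illustrative.

package example

import (
	"time"

	"github.com/safing/portmaster/service/mgr"
)

// scheduleCleanup creates a task with a nil errorFn, so failures are only
// logged by the manager, and schedules it to run once after 30 seconds and
// hourly afterwards (Delay followed by Repeat, as several call sites in
// this patch do). An immediate run can later be requested with task.Go().
func scheduleCleanup(m *mgr.Manager) *mgr.Task {
	task := m.NewTask("cache cleanup", func(w *mgr.WorkerCtx) error {
		// Periodic work goes here; w.Ctx() carries cancellation.
		return nil
	}, nil)

	task.Delay(30 * time.Second).Repeat(1 * time.Hour)
	return task
}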
-func (m *Manager) Repeat(name string, period time.Duration, fn func(w *WorkerCtx) error) { - go m.manageRepeatedWorker(name, period, fn) +func (m *Manager) Repeat(name string, period time.Duration, fn func(w *WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *Task { + t := m.NewTask(name, fn, errorFn) + return t.Repeat(period) } -func (m *Manager) manageRepeatedWorker(name string, period time.Duration, fn func(w *WorkerCtx) error) { - m.workerStart() - defer m.workerDone() - - w := &WorkerCtx{ - logger: m.logger.With("worker", name), - } - - repeatTick := time.NewTicker(period) - execCnt := 0 - - backoff := time.Second - failCnt := 0 - -repeat: - for { - // Wait for repeat period. - if execCnt > 0 { - select { - case <-repeatTick.C: - case <-m.ctx.Done(): - return - } - } - - // Execute function. - execCnt++ - panicInfo, err := m.runWorker(w, fn) - - switch { - case err == nil: - // No error means that the worker is finished. - continue repeat - - case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): - // A canceled context or exceeded deadline also means that the worker is finished. - continue repeat - - default: - // Any other errors triggers a restart with backoff. - - // If manager is stopping, just log error and return. - if m.IsDone() { - if panicInfo != "" { - m.Error( - "worker failed", - "err", err, - "file", panicInfo, - ) - } else { - m.Error( - "worker failed", - "err", err, - ) - } - return - } - - // Count failure and increase backoff (up to limit), - failCnt++ - backoff *= 2 - if backoff > time.Minute { - backoff = time.Minute - } - - // Log error and retry after backoff duration. - if panicInfo != "" { - m.Error( - "repeated worker failed", - "execCnt", execCnt, - "failCnt", failCnt, - "backoff", backoff, - "err", err, - "file", panicInfo, - ) - } else { - m.Error( - "repeated worker failed", - "execCnt", execCnt, - "failCnt", failCnt, - "backoff", backoff, - "err", err, - ) - } - select { - case <-time.After(backoff): - case <-m.ctx.Done(): - return - } - repeatTick.Reset(period) - } - } +// Delay starts the given function delayed in a goroutine (as a "worker"). +// The worker context has +// - A separate context which is canceled when the functions returns. +// - Access to named structure logging. +// - By default error/panic will be logged. For custom behavior supply errorFn, the argument is optional. +// - Panic catching. +// - Flow control helpers. +func (m *Manager) Delay(name string, period time.Duration, fn func(w *WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *Task { + t := m.NewTask(name, fn, errorFn) + return t.Delay(period) } diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go index 831b72bee..9f1fb6ca9 100644 --- a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -258,12 +258,9 @@ func (nq *NetQuery) Start(m *mgr.Manager) error { } }) - nq.mgr.Delay("network history cleaner delay", 10*time.Minute, func(_ *mgr.WorkerCtx) error { - nq.mgr.Repeat("network history cleaner delay", 1*time.Hour, func(w *mgr.WorkerCtx) error { - return nq.Store.CleanupHistory(w.Ctx()) - }) - return nil - }) + nq.mgr.Delay("network history cleaner", 10*time.Minute, func(w *mgr.WorkerCtx) error { + return nq.Store.CleanupHistory(w.Ctx()) + }, nil).Repeat(1 * time.Hour) // For debugging, provide a simple direct SQL query interface using // the runtime database. @@ -316,7 +313,7 @@ var ( shimLoaded atomic.Bool ) -// New returns a new NetQuery module. 
+// NewModule returns a new NetQuery module. func NewModule(instance instance) (*NetQuery, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") diff --git a/service/profile/config-update.go b/service/profile/config-update.go index b0436be6b..1c7382413 100644 --- a/service/profile/config-update.go +++ b/service/profile/config-update.go @@ -33,7 +33,7 @@ func registerConfigUpdater() error { const globalConfigProfileErrorID = "profile:global-profile-error" -func updateGlobalConfigProfile(ctx context.Context) error { +func updateGlobalConfigProfile(_ context.Context) error { cfgLock.Lock() defer cfgLock.Unlock() @@ -134,10 +134,10 @@ func updateGlobalConfigProfile(ctx context.Context) error { // Create task after first failure. // Schedule task. - module.mgr.Delay("retry updating global config profile", 15*time.Second, + _ = module.mgr.Delay("retry updating global config profile", 15*time.Second, func(w *mgr.WorkerCtx) error { return updateGlobalConfigProfile(w.Ctx()) - }) + }, nil) // Add module warning to inform user. module.States.Add(mgr.State{ diff --git a/service/resolver/failing.go b/service/resolver/failing.go index 8c562642f..e950f1e31 100644 --- a/service/resolver/failing.go +++ b/service/resolver/failing.go @@ -73,6 +73,9 @@ func (brc *BasicResolverConn) ResetFailure() { func checkFailingResolvers(wc *mgr.WorkerCtx) error { var resolvers []*Resolver + // Set next execution time. + module.failingResolverTask.Delay(time.Duration(nameserverRetryRate()) * time.Second) + // Make a copy of the resolver list. func() { resolversLock.Lock() @@ -115,8 +118,5 @@ func checkFailingResolvers(wc *mgr.WorkerCtx) error { } } - // Set next execution time. - module.mgr.Delay("check failing resolvers", time.Duration(nameserverRetryRate())*time.Second, checkFailingResolvers) - return nil } diff --git a/service/resolver/main.go b/service/resolver/main.go index 7083cbcc1..5b77be84d 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -25,6 +25,9 @@ type ResolverModule struct { mgr *mgr.Manager instance instance + failingResolverTask *mgr.Task + suggestUsingStaleCacheTask *mgr.Task + States *mgr.StateMgr } @@ -99,7 +102,8 @@ func start() error { }) // Check failing resolvers regularly and when the network changes. 
- module.mgr.Do("check failing resolvers", checkFailingResolvers) + module.failingResolverTask = module.mgr.NewTask("check failing resolvers", checkFailingResolvers, nil) + module.failingResolverTask.Go() module.instance.NetEnv().EventNetworkChange.AddCallback( "check failing resolvers", func(wc *mgr.WorkerCtx, _ struct{}) (bool, error) { @@ -107,7 +111,8 @@ func start() error { return false, nil }) - module.mgr.Go("suggest using stale cache", suggestUsingStaleCacheTask) + module.suggestUsingStaleCacheTask = module.mgr.NewTask("suggest using stale cache", suggestUsingStaleCacheTask, nil) + module.suggestUsingStaleCacheTask.Go() module.mgr.Go( "mdns handler", diff --git a/service/resolver/metrics.go b/service/resolver/metrics.go index f0062c9ef..2068eafb6 100644 --- a/service/resolver/metrics.go +++ b/service/resolver/metrics.go @@ -103,7 +103,7 @@ func suggestUsingStaleCacheTask(_ *mgr.WorkerCtx) error { } if scheduleNextCall { - module.mgr.Delay("suggest using stale cache", 2*time.Minute, suggestUsingStaleCacheTask) + _ = module.suggestUsingStaleCacheTask.Delay(2 * time.Minute) } resetSlowQueriesSensorValue() return nil diff --git a/service/updates/main.go b/service/updates/main.go index 5f4d1dee6..55a47d5c2 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -111,7 +111,7 @@ func prep() error { func start() error { initConfig() - module.mgr.Repeat("automatic restart", 10*time.Minute, automaticRestart) + _ = module.mgr.Repeat("automatic restart", 10*time.Minute, automaticRestart, nil) module.instance.Config().EventConfigChange.AddCallback("update registry config", updateRegistryConfig) @@ -190,23 +190,15 @@ func start() error { } // start updater task - // FIXME: remove - // updateTask = module.NewTask("updater", func(ctx context.Context, task *modules.Task) error { - // return checkForUpdates(ctx) - // }) + module.updateTask = module.mgr.NewTask("updater", checkForUpdates, nil) if !disableTaskSchedule { - module.mgr.Repeat("updater", 30*time.Minute, checkForUpdates) - - // FIXME: remove - // updateTask. - // Repeat(updateTaskRepeatDuration). - // MaxDelay(30 * time.Minute) + _ = module.updateTask.Repeat(30 * time.Minute) } - // if updateASAP { - // updateTask.StartASAP() - // } + if updateASAP { + module.updateTask.Go() + } // react to upgrades if err := initUpgrader(); err != nil { @@ -221,10 +213,6 @@ func start() error { // TriggerUpdate queues the update task to execute ASAP. func TriggerUpdate(forceIndexCheck, downloadAll bool) error { switch { - // FIXME(vladimir): provide alternative for this - // case !module.Online(): - // updateASAP = true - case !forceIndexCheck && !enableSoftwareUpdates() && !enableIntelUpdates(): return errors.New("automatic updating is disabled") @@ -237,12 +225,7 @@ func TriggerUpdate(forceIndexCheck, downloadAll bool) error { } // If index check if forced, start quicker. 
- // FIXME(vladimir): provide alternative for this - // if forceIndexCheck { - // updateTask.StartASAP() - // } else { - // updateTask.Queue() - // } + module.updateTask.Go() } log.Debugf("updates: triggering update to run as soon as possible") diff --git a/service/updates/module.go b/service/updates/module.go index 8e62bc67b..ef5858eba 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -16,6 +16,8 @@ type Updates struct { instance instance shutdownFunc func(exitCode int) + updateTask *mgr.Task + EventResourcesUpdated *mgr.EventMgr[struct{}] EventVersionsUpdated *mgr.EventMgr[struct{}] diff --git a/service/updates/upgrader.go b/service/updates/upgrader.go index 889ccf4c9..093647d65 100644 --- a/service/updates/upgrader.go +++ b/service/updates/upgrader.go @@ -182,14 +182,14 @@ func upgradeHub() error { // Increase update checks in order to detect aborts better. if !disableTaskSchedule { - updateTask.Repeat(10 * time.Minute) + module.updateTask.Repeat(10 * time.Minute) } } else { AbortRestart() // Set update task schedule back to normal. if !disableTaskSchedule { - updateTask.Repeat(updateTaskRepeatDuration) + module.updateTask.Repeat(updateTaskRepeatDuration) } } diff --git a/spn/access/module.go b/spn/access/module.go index 7f5ada2df..5db5a44c1 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -20,12 +20,16 @@ type Access struct { mgr *mgr.Manager instance instance + updateAccountTask *mgr.Task + EventAccountUpdate *mgr.EventMgr[struct{}] } func (a *Access) Start(m *mgr.Manager) error { a.mgr = m a.EventAccountUpdate = mgr.NewEventMgr[struct{}](AccountUpdateEvent, m) + a.updateAccountTask = m.NewTask("update account", UpdateAccount, nil) + if err := prep(); err != nil { return err } @@ -83,10 +87,7 @@ func start() error { loadTokens() // Register new task. - module.mgr.Delay("update account delayed", 1*time.Minute, func(_ *mgr.WorkerCtx) error { - module.mgr.Repeat("update account", 24*time.Hour, UpdateAccount) - return nil - }) + module.updateAccountTask.Delay(1 * time.Minute) } return nil @@ -94,10 +95,6 @@ func start() error { func stop() error { if conf.Client() { - // Stop account update task. - // accountUpdateTask.Cancel() - // accountUpdateTask = nil - // Store tokens to database. storeTokens() } @@ -110,10 +107,13 @@ func stop() error { // UpdateAccount updates the user account and fetches new tokens, if needed. func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error { + // Schedule next call this will change if other conditions are met bellow. + module.updateAccountTask.Delay(24 * time.Hour) + // Retry sooner if the token issuer is failing. defer func() { if tokenIssuerIsFailing.IsSet() { - module.mgr.Delay("update account", tokenIssuerRetryDuration, UpdateAccount) + module.updateAccountTask.Delay(tokenIssuerRetryDuration) } }() @@ -145,16 +145,14 @@ func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error { case time.Until(*u.Subscription.EndsAt) < 24*time.Hour && time.Since(*u.Subscription.EndsAt) < 24*time.Hour: // Update account every hour for 24h hours before and after the subscription ends. - // TODO(vladimir): Go rotunes will leak if this is called more then once. Figure out a way to test if this is already running. - module.mgr.Delay("update account", 1*time.Hour, UpdateAccount) + module.updateAccountTask.Delay(1 * time.Hour) case u.Subscription.NextBillingDate == nil: // No auto-subscription. 
case time.Until(*u.Subscription.NextBillingDate) < 24*time.Hour && time.Since(*u.Subscription.NextBillingDate) < 24*time.Hour: // Update account every hour 24h hours before and after the next billing date. - // TODO(vladimir): Go rotunes will leak if this is called more then once. Figure out a way to test if this is already running. - module.mgr.Delay("update account", 1*time.Hour, UpdateAccount) + module.updateAccountTask.Delay(1 * time.Hour) } return nil @@ -188,7 +186,7 @@ func tokenIssuerFailed() { // return // } - module.mgr.Delay("update account", tokenIssuerRetryDuration, UpdateAccount) + module.updateAccountTask.Delay(tokenIssuerRetryDuration) } // IsLoggedIn returns whether a User is currently logged in. diff --git a/spn/captain/hooks.go b/spn/captain/hooks.go index c16d46ecf..119f6d811 100644 --- a/spn/captain/hooks.go +++ b/spn/captain/hooks.go @@ -32,7 +32,7 @@ func handleCraneUpdate(crane *docks.Crane) { func updateConnectionStatus() { // Delay updating status for a better chance to combine multiple changes. - module.mgr.Delay("maintain public status", maintainStatusUpdateDelay, maintainPublicStatus) + module.maintainPublicStatus.Delay(maintainStatusUpdateDelay) // Check if we lost all connections and trigger a pending restart if we did. for _, crane := range docks.GetAllAssignedCranes() { diff --git a/spn/captain/module.go b/spn/captain/module.go index b6157dc0f..f71871f2b 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -35,7 +35,8 @@ type Captain struct { shutdownFunc func(exitCode int) - healthCheckTicker *mgr.SleepyTicker + healthCheckTicker *mgr.SleepyTicker + maintainPublicStatus *mgr.Task States *mgr.StateMgr EventSPNConnected *mgr.EventMgr[struct{}] @@ -44,6 +45,7 @@ type Captain struct { func (c *Captain) Start(m *mgr.Manager) error { c.mgr = m c.EventSPNConnected = mgr.NewEventMgr[struct{}](SPNConnectedEvent, m) + c.maintainPublicStatus = m.NewTask("maintain public status", maintainPublicStatus, nil) if err := prep(); err != nil { return err } @@ -177,10 +179,7 @@ func start() error { // network optimizer if conf.PublicHub() { - module.mgr.Delay("optimize network delay", 15*time.Second, func(_ *mgr.WorkerCtx) error { - module.mgr.Repeat("optimize network", 1*time.Minute, optimizeNetwork) - return nil - }) + module.mgr.Delay("optimize network", 15*time.Second, optimizeNetwork, nil).Repeat(1 * time.Minute) } // client + home hub manager diff --git a/spn/captain/public.go b/spn/captain/public.go index 62e19d251..427ff9aac 100644 --- a/spn/captain/public.go +++ b/spn/captain/public.go @@ -87,11 +87,11 @@ func loadPublicIdentity() (err error) { } func prepPublicIdentityMgmt() error { - module.mgr.Repeat("maintain public status", maintainStatusInterval, maintainPublicStatus) + module.maintainPublicStatus.Repeat(maintainStatusInterval) module.instance.Config().EventConfigChange.AddCallback("update public identity from config", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { - module.mgr.Delay("maintain public identity", 5*time.Minute, maintainPublicIdentity) + module.maintainPublicStatus.Delay(5 * time.Minute) return false, nil }) return nil @@ -99,7 +99,7 @@ func prepPublicIdentityMgmt() error { // TriggerHubStatusMaintenance queues the Hub status update task to be executed. 
func TriggerHubStatusMaintenance() { - module.mgr.Go("maintain public status", maintainPublicStatus) + module.maintainPublicStatus.Go() } func maintainPublicIdentity(ctx *mgr.WorkerCtx) error { diff --git a/spn/crew/module.go b/spn/crew/module.go index f96cee59c..3523a10b6 100644 --- a/spn/crew/module.go +++ b/spn/crew/module.go @@ -24,7 +24,7 @@ func (c *Crew) Stop(m *mgr.Manager) error { } func start() error { - module.mgr.Repeat("sticky cleaner", 10*time.Minute, cleanStickyHubs) + _ = module.mgr.Repeat("sticky cleaner", 10*time.Minute, cleanStickyHubs, nil) return registerMetrics() } diff --git a/spn/navigator/measurements.go b/spn/navigator/measurements.go index 2fd20abdd..f137c2b1f 100644 --- a/spn/navigator/measurements.go +++ b/spn/navigator/measurements.go @@ -1,11 +1,11 @@ package navigator import ( - "context" "sort" "time" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/terminal" ) @@ -24,7 +24,7 @@ const ( // 1000c -> 100h -> capped to 50h. ) -func (m *Map) measureHubs(ctx context.Context) error { +func (m *Map) measureHubs(wc *mgr.WorkerCtx) error { if home, _ := m.GetHome(); home == nil { log.Debug("spn/navigator: skipping measuring, no home hub set") return nil @@ -73,7 +73,7 @@ func (m *Map) measureHubs(ctx context.Context) error { } // Measure connection. - tErr := docks.MeasureHub(ctx, pin.Hub, checkWithTTL) + tErr := docks.MeasureHub(wc.Ctx(), pin.Hub, checkWithTTL) // Independent of outcome, recalculate the cost. latency, _ := pin.measurements.GetLatency() diff --git a/spn/navigator/module.go b/spn/navigator/module.go index 774dc670f..75c44bcb9 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -110,21 +110,12 @@ geoInitCheck: } // TODO: delete superseded hubs after x amount of time - module.mgr.Delay("update states delay", 3*time.Minute, func(w *mgr.WorkerCtx) error { - module.mgr.Repeat("update states", 1*time.Hour, Main.updateStates) - return nil - }) - module.mgr.Delay("update failing states delay", 3*time.Minute, func(w *mgr.WorkerCtx) error { - module.mgr.Repeat("update states", 1*time.Minute, Main.updateFailingStates) - return nil - }) + _ = module.mgr.Delay("update states", 3*time.Minute, Main.updateStates, nil).Repeat(1 * time.Hour) + _ = module.mgr.Delay("update failing states delay", 3*time.Minute, Main.updateFailingStates, nil).Repeat(1 * time.Minute) if conf.PublicHub() { // Only measure Hubs on public Hubs. - module.mgr.Delay("measure hubs delay", 5*time.Minute, func(w *mgr.WorkerCtx) error { - module.mgr.Repeat("measure hubs", 1*time.Minute, Main.updateFailingStates) - return nil - }) + module.mgr.Delay("measure hubs delay", 5*time.Minute, Main.measureHubs, nil).Repeat(1 * time.Minute) // Only register metrics on Hubs, as they only make sense there. 
err := registerMetrics() diff --git a/spn/patrol/module.go b/spn/patrol/module.go index 414c06165..78211a9be 100644 --- a/spn/patrol/module.go +++ b/spn/patrol/module.go @@ -24,7 +24,7 @@ func (p *Patrol) Start(m *mgr.Manager) error { p.EventChangeSignal = mgr.NewEventMgr[struct{}](ChangeSignalEventName, m) if conf.PublicHub() { - m.Repeat("connectivity test", 5*time.Minute, connectivityCheckTask) + m.Repeat("connectivity test", 5*time.Minute, connectivityCheckTask, nil) } return nil } From 7880b13070bc6f2ea57a9d53ec7224b3c4f43748 Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 24 Jun 2024 15:26:44 +0200 Subject: [PATCH 12/56] [WIP] Add second take for scheduling workers --- service/mgr/scheduler.go | 191 +++++++++++++++++++++++++++++++++++++++ service/mgr/worker.go | 14 ++- 2 files changed, 202 insertions(+), 3 deletions(-) create mode 100644 service/mgr/scheduler.go diff --git a/service/mgr/scheduler.go b/service/mgr/scheduler.go new file mode 100644 index 000000000..25887a918 --- /dev/null +++ b/service/mgr/scheduler.go @@ -0,0 +1,191 @@ +package mgr + +import ( + "context" + "errors" + "sync/atomic" + "time" +) + +// Scheduler schedules a worker. +type Scheduler struct { + mgr *Manager + ctx *WorkerCtx + + name string + fn func(w *WorkerCtx) error + + run chan struct{} + eval chan struct{} + + delay atomic.Int64 + repeat atomic.Int64 + keepAlive atomic.Bool + + errorFn func(c *WorkerCtx, err error, panicInfo string) +} + +// NewScheduler creates a new scheduler for the given worker function. +// Errors and panic will only be logged by default. +// If custom behavior is required, supply an errorFn. +// If all scheduling has ended, the scheduler will end itself, +// including all related workers, except if keep-alive is enabled. +func (m *Manager) NewScheduler(name string, fn func(w *WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *Scheduler { + // Create task context. + wCtx := &WorkerCtx{ + logger: m.logger.With("worker", name), + } + wCtx.ctx, wCtx.cancelCtx = context.WithCancel(m.Ctx()) + + s := &Scheduler{ + mgr: m, + ctx: wCtx, + name: name, + fn: fn, + run: make(chan struct{}, 1), + eval: make(chan struct{}, 1), + errorFn: errorFn, + } + + go s.taskMgr() + return s +} + +func (s *Scheduler) taskMgr() { + // If the task manager ends, end all descendents too. + defer s.ctx.cancelCtx() + + // Timers and tickers. + var ( + ticker *time.Ticker + nextExecute <-chan time.Time + changed bool + ) + defer func() { + if ticker != nil { + ticker.Stop() + } + }() + +manage: + for { + // Select timer / ticker. + switch { + case s.delay.Load() > 0: + if changed { + nextExecute = time.After(time.Duration(s.delay.Load())) + changed = false + } + + case s.repeat.Load() > 0: + if changed { + if ticker != nil { + ticker.Reset(time.Duration(s.repeat.Load())) + } else { + ticker = time.NewTicker(time.Duration(s.repeat.Load())) + } + nextExecute = ticker.C + changed = false + } + + case !s.keepAlive.Load(): + // If no delay or repeat is set, end task. + // Except, if explicitly set to be kept alive. + return + + default: + // No trigger is set, disable timed execution. + if ticker != nil { + ticker.Stop() + ticker = nil + } + nextExecute = nil + } + + // Wait for action or ticker. + select { + case <-s.run: + case <-nextExecute: + case <-s.eval: + changed = true + goto manage + case <-s.ctx.Done(): + return + } + + // Run worker. + panicInfo, err := s.mgr.runWorker(s.ctx, s.fn) + switch { + case err == nil: + // Continue with scheduling. 
+ case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + // Worker was canceled, continue with scheduling. + // A canceled context or exceeded deadline also means that the worker is finished. + + default: + // Log error and return. + if panicInfo != "" { + s.ctx.Error( + "worker failed", + "err", err, + "file", panicInfo, + ) + } else { + s.ctx.Error( + "worker failed", + "err", err, + ) + } + + // Execute error function, else, end the scheduler. + if s.errorFn != nil { + s.errorFn(s.ctx, err, panicInfo) + } else { + return + } + } + } +} + +// Go executes the worker immediately. +// If the worker is currently being executed, +// the next execution will commence afterwards. +func (s *Scheduler) Go() { + select { + case s.run <- struct{}{}: + default: + } +} + +// KeepAlive instructs the scheduler to not self-destruct, +// even if all scheduled work is complete. +func (s *Scheduler) KeepAlive() { + s.keepAlive.Store(true) +} + +// Stop immediately stops the scheduler and all related workers. +func (s *Scheduler) Stop() { + s.ctx.cancelCtx() +} + +// Delay will schedule the worker to run after the given duration. +// If set, the repeat schedule will continue afterwards. +// Disable the delay by passing 0. +func (s *Scheduler) Delay(duration time.Duration) { + s.delay.Store(int64(duration)) + s.check() +} + +// Repeat will repeatedly execute the worker using the given interval. +// Disable the repeating by passing 0. +func (s *Scheduler) Repeat(interval time.Duration) { + s.repeat.Store(int64(interval)) + s.check() +} + +func (s *Scheduler) check() { + select { + case s.eval <- struct{}{}: + default: + } +} diff --git a/service/mgr/worker.go b/service/mgr/worker.go index 395e45968..397c14e00 100644 --- a/service/mgr/worker.go +++ b/service/mgr/worker.go @@ -23,7 +23,8 @@ type WorkerCtx struct { ctx context.Context cancelCtx context.CancelFunc - logger *slog.Logger + scheduler *Scheduler // TODO: Attach to context instead? + logger *slog.Logger } // AddToCtx adds the WorkerCtx to the given context. @@ -52,6 +53,12 @@ func (w *WorkerCtx) Cancel() { w.cancelCtx() } +// Scheduler returns the scheduler the worker was started from. +// Returns nil if the worker is not associated with a scheduler. +func (w *WorkerCtx) Scheduler() *Scheduler { + return w.scheduler +} + // Done returns the context Done channel. func (w *WorkerCtx) Done() <-chan struct{} { return w.ctx.Done() @@ -208,6 +215,7 @@ func (m *Manager) Do(name string, fn func(w *WorkerCtx) error) error { // Create context. w := &WorkerCtx{ + ctx: m.Ctx(), logger: m.logger.With("worker", name), } @@ -242,7 +250,7 @@ func (m *Manager) Do(name string, fn func(w *WorkerCtx) error) error { func (m *Manager) runWorker(w *WorkerCtx, fn func(w *WorkerCtx) error) (panicInfo string, err error) { // Create worker context that is canceled when worker finished or dies. - w.ctx, w.cancelCtx = context.WithCancel(m.Ctx()) + w.ctx, w.cancelCtx = context.WithCancel(w.ctx) defer w.Cancel() // Recover from panic. 
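
For illustration, a minimal sketch of how a caller could drive the Scheduler introduced in this patch; the manager constructor, the worker name, and the worker body are assumptions for the sketch and are not taken from the diff:

package main

import (
	"time"

	"github.com/safing/portmaster/service/mgr"
)

func main() {
	m := mgr.New("example") // assumed constructor for a top-level Manager

	// Errors and panics are only logged because errorFn is nil.
	s := m.NewScheduler("clean cache", func(w *mgr.WorkerCtx) error {
		// Periodic work goes here; returning nil keeps the schedule going.
		return nil
	}, nil)

	s.Delay(1 * time.Minute)   // first run after one minute
	s.Repeat(30 * time.Minute) // then repeat every 30 minutes
	s.Go()                     // or trigger an immediate run on demand

	// Stop ends the scheduler and cancels its workers.
	defer s.Stop()
	time.Sleep(2 * time.Minute) // keep the sketch alive long enough for one run
}
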
@@ -269,7 +277,7 @@ func (m *Manager) runWorker(w *WorkerCtx, fn func(w *WorkerCtx) error) (panicInf foundPanic = true } } else { - if strings.Contains(line, "mycoria") { + if strings.Contains(line, "portmaster") { if i+1 < len(stackLines) { panicInfo = strings.SplitN(strings.TrimSpace(stackLines[i+1]), " ", 2)[0] } From f9eeae001fdefa4d3b8da4c46491555eb503749a Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 24 Jun 2024 15:38:28 +0200 Subject: [PATCH 13/56] [WIP] Add FIXME for bugs in new scheduler --- service/mgr/scheduler.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/service/mgr/scheduler.go b/service/mgr/scheduler.go index 25887a918..405741514 100644 --- a/service/mgr/scheduler.go +++ b/service/mgr/scheduler.go @@ -78,6 +78,8 @@ manage: } case s.repeat.Load() > 0: + // FIXME: bug: race condition of multiple evals. + // FIXME: bug: After delay, changed will be false. if changed { if ticker != nil { ticker.Reset(time.Duration(s.repeat.Load())) @@ -108,7 +110,7 @@ manage: case <-nextExecute: case <-s.eval: changed = true - goto manage + continue manage case <-s.ctx.Done(): return } From 91f3c709832d0d8588e927f4b4f5ee384c9afa2d Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Mon, 24 Jun 2024 18:12:13 +0300 Subject: [PATCH 14/56] [WIP] Add minor improvements to scheduler --- service/mgr/scheduler.go | 40 ++++++---------------------------------- 1 file changed, 6 insertions(+), 34 deletions(-) diff --git a/service/mgr/scheduler.go b/service/mgr/scheduler.go index 405741514..85b09cbc0 100644 --- a/service/mgr/scheduler.go +++ b/service/mgr/scheduler.go @@ -52,43 +52,20 @@ func (m *Manager) NewScheduler(name string, fn func(w *WorkerCtx) error, errorFn } func (s *Scheduler) taskMgr() { - // If the task manager ends, end all descendents too. + // If the task manager ends, end all descendants too. defer s.ctx.cancelCtx() - // Timers and tickers. - var ( - ticker *time.Ticker - nextExecute <-chan time.Time - changed bool - ) - defer func() { - if ticker != nil { - ticker.Stop() - } - }() - + // Timers. + var nextExecute <-chan time.Time manage: for { // Select timer / ticker. switch { - case s.delay.Load() > 0: - if changed { - nextExecute = time.After(time.Duration(s.delay.Load())) - changed = false - } + case s.delay.Swap(0) > 0: + nextExecute = time.After(time.Duration(s.delay.Load())) case s.repeat.Load() > 0: - // FIXME: bug: race condition of multiple evals. - // FIXME: bug: After delay, changed will be false. - if changed { - if ticker != nil { - ticker.Reset(time.Duration(s.repeat.Load())) - } else { - ticker = time.NewTicker(time.Duration(s.repeat.Load())) - } - nextExecute = ticker.C - changed = false - } + nextExecute = time.After(time.Duration(s.repeat.Load())) case !s.keepAlive.Load(): // If no delay or repeat is set, end task. @@ -97,10 +74,6 @@ manage: default: // No trigger is set, disable timed execution. 
- if ticker != nil { - ticker.Stop() - ticker = nil - } nextExecute = nil } @@ -109,7 +82,6 @@ manage: case <-s.run: case <-nextExecute: case <-s.eval: - changed = true continue manage case <-s.ctx.Done(): return From f6cb6b42a160d001b651b0e2da3bcc59102c9d50 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Tue, 25 Jun 2024 18:21:06 +0300 Subject: [PATCH 15/56] [WIP] Add new worker scheduler --- service/mgr/scheduler.go | 165 ------------------- service/mgr/task.go | 164 ------------------- service/mgr/worker.go | 17 +- service/mgr/workermgr.go | 300 ++++++++++++++++++++++++++++++++++ service/mgr/workermgr_test.go | 149 +++++++++++++++++ 5 files changed, 459 insertions(+), 336 deletions(-) delete mode 100644 service/mgr/scheduler.go delete mode 100644 service/mgr/task.go create mode 100644 service/mgr/workermgr.go create mode 100644 service/mgr/workermgr_test.go diff --git a/service/mgr/scheduler.go b/service/mgr/scheduler.go deleted file mode 100644 index 85b09cbc0..000000000 --- a/service/mgr/scheduler.go +++ /dev/null @@ -1,165 +0,0 @@ -package mgr - -import ( - "context" - "errors" - "sync/atomic" - "time" -) - -// Scheduler schedules a worker. -type Scheduler struct { - mgr *Manager - ctx *WorkerCtx - - name string - fn func(w *WorkerCtx) error - - run chan struct{} - eval chan struct{} - - delay atomic.Int64 - repeat atomic.Int64 - keepAlive atomic.Bool - - errorFn func(c *WorkerCtx, err error, panicInfo string) -} - -// NewScheduler creates a new scheduler for the given worker function. -// Errors and panic will only be logged by default. -// If custom behavior is required, supply an errorFn. -// If all scheduling has ended, the scheduler will end itself, -// including all related workers, except if keep-alive is enabled. -func (m *Manager) NewScheduler(name string, fn func(w *WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *Scheduler { - // Create task context. - wCtx := &WorkerCtx{ - logger: m.logger.With("worker", name), - } - wCtx.ctx, wCtx.cancelCtx = context.WithCancel(m.Ctx()) - - s := &Scheduler{ - mgr: m, - ctx: wCtx, - name: name, - fn: fn, - run: make(chan struct{}, 1), - eval: make(chan struct{}, 1), - errorFn: errorFn, - } - - go s.taskMgr() - return s -} - -func (s *Scheduler) taskMgr() { - // If the task manager ends, end all descendants too. - defer s.ctx.cancelCtx() - - // Timers. - var nextExecute <-chan time.Time -manage: - for { - // Select timer / ticker. - switch { - case s.delay.Swap(0) > 0: - nextExecute = time.After(time.Duration(s.delay.Load())) - - case s.repeat.Load() > 0: - nextExecute = time.After(time.Duration(s.repeat.Load())) - - case !s.keepAlive.Load(): - // If no delay or repeat is set, end task. - // Except, if explicitly set to be kept alive. - return - - default: - // No trigger is set, disable timed execution. - nextExecute = nil - } - - // Wait for action or ticker. - select { - case <-s.run: - case <-nextExecute: - case <-s.eval: - continue manage - case <-s.ctx.Done(): - return - } - - // Run worker. - panicInfo, err := s.mgr.runWorker(s.ctx, s.fn) - switch { - case err == nil: - // Continue with scheduling. - case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): - // Worker was canceled, continue with scheduling. - // A canceled context or exceeded deadline also means that the worker is finished. - - default: - // Log error and return. 
- if panicInfo != "" { - s.ctx.Error( - "worker failed", - "err", err, - "file", panicInfo, - ) - } else { - s.ctx.Error( - "worker failed", - "err", err, - ) - } - - // Execute error function, else, end the scheduler. - if s.errorFn != nil { - s.errorFn(s.ctx, err, panicInfo) - } else { - return - } - } - } -} - -// Go executes the worker immediately. -// If the worker is currently being executed, -// the next execution will commence afterwards. -func (s *Scheduler) Go() { - select { - case s.run <- struct{}{}: - default: - } -} - -// KeepAlive instructs the scheduler to not self-destruct, -// even if all scheduled work is complete. -func (s *Scheduler) KeepAlive() { - s.keepAlive.Store(true) -} - -// Stop immediately stops the scheduler and all related workers. -func (s *Scheduler) Stop() { - s.ctx.cancelCtx() -} - -// Delay will schedule the worker to run after the given duration. -// If set, the repeat schedule will continue afterwards. -// Disable the delay by passing 0. -func (s *Scheduler) Delay(duration time.Duration) { - s.delay.Store(int64(duration)) - s.check() -} - -// Repeat will repeatedly execute the worker using the given interval. -// Disable the repeating by passing 0. -func (s *Scheduler) Repeat(interval time.Duration) { - s.repeat.Store(int64(interval)) - s.check() -} - -func (s *Scheduler) check() { - select { - case s.eval <- struct{}{}: - default: - } -} diff --git a/service/mgr/task.go b/service/mgr/task.go deleted file mode 100644 index 385767c26..000000000 --- a/service/mgr/task.go +++ /dev/null @@ -1,164 +0,0 @@ -package mgr - -import ( - "sync" - "time" -) - -type taskMode int - -const ( - taskModeOnDemand taskMode = iota - taskModeDelay - taskModeRepeat -) - -// Task holds info about a task that can be scheduled for execution later. -type Task struct { - name string - runChannel chan struct{} - - tickerMutex sync.Mutex - mode taskMode - runTicker *time.Ticker - repeatDuration time.Duration - - mgr *Manager -} - -// NewTask creates a new task that can be scheduled for execution later. -// By default error/panic will be logged. For custom behavior supply errorFn, the argument is optional. 
-func (m *Manager) NewTask(name string, taskFn func(*WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *Task { - t := &Task{ - name: name, - runChannel: make(chan struct{}), - mgr: m, - mode: taskModeOnDemand, - repeatDuration: 0, - } - - go t.taskLoop(taskFn, errorFn) - - return t -} - -func (t *Task) initTicker(duration time.Duration) { - t.runTicker = time.NewTicker(duration) - go func() { - for { - select { - case <-t.runTicker.C: - t.tickerMutex.Lock() - - // Handle execution - switch t.mode { - case taskModeDelay: - // Run once and disable delay - t.Go() - if t.repeatDuration == 0 { - t.mode = taskModeOnDemand - // Reset the timer with a large value so it does not eat unnecessary resources, - t.runTicker.Reset(24 * time.Hour) - } else { - // Repeat was called, switch to repeat mode - t.mode = taskModeRepeat - t.runTicker.Reset(t.repeatDuration) - } - case taskModeRepeat: - t.Go() - case taskModeOnDemand: - // On Demand is triggered only when the Go function as called - } - - t.tickerMutex.Unlock() - case <-t.mgr.Done(): - return - } - } - }() -} - -func (t *Task) stopTicker() { - t.tickerMutex.Lock() - defer t.tickerMutex.Unlock() - if t.runTicker != nil { - t.runTicker.Stop() - t.runTicker = nil - } -} - -func (t *Task) taskLoop(fn func(*WorkerCtx) error, errorFn func(*WorkerCtx, error, string)) { - t.mgr.workerStart() - defer t.mgr.workerDone() - defer t.stopTicker() - - w := &WorkerCtx{ - logger: t.mgr.logger.With("worker", t.name), - } - for { - // Wait for a signal to run. - select { - case <-t.runChannel: - case <-w.Done(): - return - } - - panicInfo, err := t.mgr.runWorker(w, fn) - if err != nil { - // Handle error/panic - if panicInfo != "" { - t.mgr.Error( - "worker failed", - "err", err, - "file", panicInfo, - ) - } else { - t.mgr.Error( - "worker failed", - "err", err, - ) - } - if errorFn != nil { - errorFn(w, err, panicInfo) - } - } - } -} - -// Go will send request for the task to run and return immediately. -func (t *Task) Go() { - t.runChannel <- struct{}{} -} - -// Delay will schedule the task to run after the given delay. -// If there is active repeating, it will be pause until the delay has elapsed. -func (t *Task) Delay(delay time.Duration) *Task { - t.tickerMutex.Lock() - defer t.tickerMutex.Unlock() - t.mode = taskModeDelay - if t.runTicker == nil { - t.initTicker(delay) - } else { - t.runTicker.Reset(delay) - } - return t -} - -// Repeat will schedule the task to run every time duration elapses. -// If Delay was called before, the repeating will start after the first delay has elapsed. -func (t *Task) Repeat(duration time.Duration) *Task { - t.tickerMutex.Lock() - defer t.tickerMutex.Unlock() - t.repeatDuration = duration - - if t.mode != taskModeDelay { - t.mode = taskModeRepeat - - if t.runTicker == nil { - t.initTicker(duration) - } else { - t.runTicker.Reset(duration) - } - } - return t -} diff --git a/service/mgr/worker.go b/service/mgr/worker.go index 397c14e00..e6fa0fb4f 100644 --- a/service/mgr/worker.go +++ b/service/mgr/worker.go @@ -23,7 +23,7 @@ type WorkerCtx struct { ctx context.Context cancelCtx context.CancelFunc - scheduler *Scheduler // TODO: Attach to context instead? + workerMgr *WorkerMgr // TODO: Attach to context instead? logger *slog.Logger } @@ -55,8 +55,8 @@ func (w *WorkerCtx) Cancel() { // Scheduler returns the scheduler the worker was started from. // Returns nil if the worker is not associated with a scheduler. 
-func (w *WorkerCtx) Scheduler() *Scheduler { - return w.scheduler +func (w *WorkerCtx) WorkerMgr() *WorkerMgr { + return w.workerMgr } // Done returns the context Done channel. @@ -124,6 +124,7 @@ func (w *WorkerCtx) LogAttrs(level slog.Level, msg string, attrs ...slog.Attr) { // - Panic catching. // - Flow control helpers. func (m *Manager) Go(name string, fn func(w *WorkerCtx) error) { + m.logger.Log(m.ctx, slog.LevelInfo, "worker started", "name", name) go m.manageWorker(name, fn) } @@ -134,6 +135,7 @@ func (m *Manager) manageWorker(name string, fn func(w *WorkerCtx) error) { w := &WorkerCtx{ logger: m.logger.With("worker", name), } + w.ctx = m.ctx backoff := time.Second failCnt := 0 @@ -298,8 +300,9 @@ func (m *Manager) runWorker(w *WorkerCtx, fn func(w *WorkerCtx) error) (panicInf // - Access to named structure logging. // - By default error/panic will be logged. For custom behavior supply errorFn, the argument is optional. // - Flow control helpers. -func (m *Manager) Repeat(name string, period time.Duration, fn func(w *WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *Task { - t := m.NewTask(name, fn, errorFn) +// - Repeat is intended for long running tasks that are mostly idle. +func (m *Manager) Repeat(name string, period time.Duration, fn func(w *WorkerCtx) error) *WorkerMgr { + t := m.NewWorkerMgr(name, fn, nil) return t.Repeat(period) } @@ -310,7 +313,7 @@ func (m *Manager) Repeat(name string, period time.Duration, fn func(w *WorkerCtx // - By default error/panic will be logged. For custom behavior supply errorFn, the argument is optional. // - Panic catching. // - Flow control helpers. -func (m *Manager) Delay(name string, period time.Duration, fn func(w *WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *Task { - t := m.NewTask(name, fn, errorFn) +func (m *Manager) Delay(name string, period time.Duration, fn func(w *WorkerCtx) error) *WorkerMgr { + t := m.NewWorkerMgr(name, fn, nil) return t.Delay(period) } diff --git a/service/mgr/workermgr.go b/service/mgr/workermgr.go new file mode 100644 index 000000000..9ba716d03 --- /dev/null +++ b/service/mgr/workermgr.go @@ -0,0 +1,300 @@ +package mgr + +import ( + "context" + "errors" + "sync" + "time" +) + +// WorkerMgr schedules a worker. +type WorkerMgr struct { + mgr *Manager + ctx *WorkerCtx + + // Definition. + name string + fn func(w *WorkerCtx) error + errorFn func(c *WorkerCtx, err error, panicInfo string) + + // Manual trigger. + run chan struct{} + + // Actions. + actionLock sync.Mutex + selectAction chan struct{} + delay *workerMgrDelay + repeat *workerMgrRepeat + keepAlive *workerMgrNoop +} + +type taskAction interface { + Wait() <-chan time.Time + Ack() +} + +// Delay. +type workerMgrDelay struct { + s *WorkerMgr + timer *time.Timer +} + +func (s *WorkerMgr) newDelay(duration time.Duration) *workerMgrDelay { + return &workerMgrDelay{ + s: s, + timer: time.NewTimer(duration), + } +} +func (sd *workerMgrDelay) Wait() <-chan time.Time { return sd.timer.C } + +func (sd *workerMgrDelay) Ack() { + sd.s.actionLock.Lock() + defer sd.s.actionLock.Unlock() + + // Remove delay, as it can only fire once. + sd.s.delay = nil + + // Reset repeat. + sd.s.repeat.Reset() + + // Stop timer. + sd.timer.Stop() +} + +func (sd *workerMgrDelay) Stop() { + if sd == nil { + return + } + sd.timer.Stop() +} + +// Repeat. 
+type workerMgrRepeat struct { + ticker *time.Ticker + interval time.Duration +} + +func (s *WorkerMgr) newRepeat(interval time.Duration) *workerMgrRepeat { + return &workerMgrRepeat{ + ticker: time.NewTicker(interval), + interval: interval, + } +} + +func (sr *workerMgrRepeat) Wait() <-chan time.Time { return sr.ticker.C } +func (sr *workerMgrRepeat) Ack() {} + +func (sr *workerMgrRepeat) Reset() { + if sr == nil { + return + } + sr.ticker.Reset(sr.interval) +} + +func (sr *workerMgrRepeat) Stop() { + if sr == nil { + return + } + sr.ticker.Stop() +} + +// Noop. +type workerMgrNoop struct{} + +func (sn *workerMgrNoop) Wait() <-chan time.Time { return nil } +func (sn *workerMgrNoop) Ack() {} + +// NewWorkerMgr creates a new scheduler for the given worker function. +// Errors and panic will only be logged by default. +// If custom behavior is required, supply an errorFn. +// If all scheduling has ended, the scheduler will end itself, +// including all related workers, except if keep-alive is enabled. +func (m *Manager) NewWorkerMgr(name string, fn func(w *WorkerCtx) error, errorFn func(c *WorkerCtx, err error, panicInfo string)) *WorkerMgr { + // Create task context. + wCtx := &WorkerCtx{ + logger: m.logger.With("worker", name), + } + wCtx.ctx, wCtx.cancelCtx = context.WithCancel(m.Ctx()) + + s := &WorkerMgr{ + mgr: m, + ctx: wCtx, + name: name, + fn: fn, + errorFn: errorFn, + run: make(chan struct{}, 1), + selectAction: make(chan struct{}, 1), + } + + go s.taskMgr() + return s +} + +func (s *WorkerMgr) taskMgr() { + s.mgr.workerStart() + defer s.mgr.workerDone() + + // If the task manager ends, end all descendants too. + defer s.ctx.cancelCtx() + + // Timers and tickers. + var ( + action taskAction + ) + defer func() { + s.delay.Stop() + s.repeat.Stop() + }() + + // Wait for the first action. + select { + case <-s.selectAction: + case <-s.ctx.Done(): + return + } + +manage: + for { + // Select action. + func() { + s.actionLock.Lock() + defer s.actionLock.Unlock() + + switch { + case s.delay != nil: + action = s.delay + case s.repeat != nil: + action = s.repeat + case s.keepAlive != nil: + action = s.keepAlive + default: + action = nil + } + }() + if action == nil { + return + } + + // Wait for trigger or action. + select { + case <-action.Wait(): + action.Ack() + // Time-triggered execution. + case <-s.run: + // Manually triggered execution. + case <-s.selectAction: + // Re-select action. + continue manage + case <-s.ctx.Done(): + // Abort! + return + } + + // Run worker. + wCtx := &WorkerCtx{ + logger: s.mgr.logger.With("worker", s.name), + } + wCtx.ctx, wCtx.cancelCtx = context.WithCancel(s.mgr.Ctx()) + panicInfo, err := s.mgr.runWorker(wCtx, s.fn) + + switch { + case err == nil: + // Continue with scheduling. + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + // Worker was canceled, continue with scheduling. + // A canceled context or exceeded deadline also means that the worker is finished. + + default: + // Log error and return. + if panicInfo != "" { + s.ctx.Error( + "worker failed", + "err", err, + "file", panicInfo, + ) + } else { + s.ctx.Error( + "worker failed", + "err", err, + ) + } + + // Delegate error handling to the error function, otherwise just continue the scheduler. + // The error handler can stop the scheduler if it wants to. + if s.errorFn != nil { + s.errorFn(s.ctx, err, panicInfo) + } + } + } +} + +// Go executes the worker immediately. +// If the worker is currently being executed, +// the next execution will commence afterwards. 
+// Can only be called after calling one of Delay(), Repeat() or KeepAlive().
+func (s *WorkerMgr) Go() {
+	s.actionLock.Lock()
+	defer s.actionLock.Unlock()
+
+	// Reset repeat if set.
+	s.repeat.Reset()
+
+	// Stop delay if set.
+	s.delay.Stop()
+	s.delay = nil
+
+	// Send run command
+	select {
+	case s.run <- struct{}{}:
+	default:
+	}
+}
+
+// Stop immediately stops the scheduler and all related workers.
+func (s *WorkerMgr) Stop() {
+	s.ctx.cancelCtx()
+}
+
+// Delay will schedule the worker to run after the given duration.
+// If set, the repeat schedule will continue afterwards.
+// Disable the delay by passing 0.
+func (s *WorkerMgr) Delay(duration time.Duration) *WorkerMgr {
+	s.actionLock.Lock()
+	defer s.actionLock.Unlock()
+
+	s.delay.Stop()
+	s.delay = s.newDelay(duration)
+
+	s.check()
+	return s
+}
+
+// Repeat will repeatedly execute the worker using the given interval.
+// Disable repeating by passing 0.
+func (s *WorkerMgr) Repeat(interval time.Duration) *WorkerMgr {
+	s.actionLock.Lock()
+	defer s.actionLock.Unlock()
+
+	s.repeat.Stop()
+	s.repeat = s.newRepeat(interval)
+
+	s.check()
+	return s
+}
+
+// KeepAlive instructs the scheduler to not self-destruct,
+// even if all scheduled work is complete.
+func (s *WorkerMgr) KeepAlive() *WorkerMgr {
+	s.actionLock.Lock()
+	defer s.actionLock.Unlock()
+
+	s.keepAlive = &workerMgrNoop{}
+	return s
+}
+
+func (s *WorkerMgr) check() {
+	select {
+	case s.selectAction <- struct{}{}:
+	default:
+	}
+}
diff --git a/service/mgr/workermgr_test.go b/service/mgr/workermgr_test.go
new file mode 100644
index 000000000..758d14499
--- /dev/null
+++ b/service/mgr/workermgr_test.go
@@ -0,0 +1,149 @@
+package mgr
+
+import (
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func TestWorkerMgrDelay(t *testing.T) {
+	m := New("DelayTest")
+
+	value := atomic.Bool{}
+	value.Store(false)
+
+	// Create a task that will run after 1 second.
+	m.NewWorkerMgr("test", func(w *WorkerCtx) error {
+		value.Store(true)
+		return nil
+	}, nil).Delay(1 * time.Second)
+
+	// Check if value is set after 1 second and not before or after.
+	iterations := 0
+	for !value.Load() {
+		iterations += 1
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// 5% difference is acceptable since time.Sleep can't be perfect and it may vary on different computers.
+	if iterations < 95 || iterations > 105 {
+		t.Errorf("WorkerMgr did not delay for a whole second it=%d", iterations)
+	}
+}
+
+func TestWorkerMgrRepeat(t *testing.T) {
+	m := New("RepeatTest")
+
+	value := atomic.Bool{}
+	value.Store(false)
+
+	// Create a task that should repeat every 100 milliseconds.
+	m.NewWorkerMgr("test", func(w *WorkerCtx) error {
+		value.Store(true)
+		return nil
+	}, nil).Repeat(100 * time.Millisecond)
+
+	// Check 10 consecutive runs; they should be delayed for around 100 milliseconds each.
+	for range 10 {
+		iterations := 0
+		for !value.Load() {
+			iterations += 1
+			time.Sleep(10 * time.Millisecond)
+		}
+
+		// 10% difference is acceptable at this scale since time.Sleep can't be perfect and it may vary on different computers.
+		if iterations < 9 || iterations > 11 {
+			t.Errorf("Worker was not delayed for a 100 milliseconds it=%d", iterations)
+			return
+		}
+		// Reset value
+		value.Store(false)
+	}
+}
+
+func TestWorkerMgrDelayAndRepeat(t *testing.T) {
+	m := New("DelayAndRepeatTest")
+
+	value := atomic.Bool{}
+	value.Store(false)
+
+	// Create a task that should delay for 1 second and then repeat every 100 milliseconds.
+	m.NewWorkerMgr("test", func(w *WorkerCtx) error {
+		value.Store(true)
+		return nil
+	}, nil).Delay(1 * time.Second).Repeat(100 * time.Millisecond)
+
+	iterations := 0
+	for !value.Load() {
+		iterations += 1
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// 5% difference is acceptable since time.Sleep can't be perfect and it may vary on different computers.
+	if iterations < 95 || iterations > 105 {
+		t.Errorf("WorkerMgr did not delay for a whole second it=%d", iterations)
+	}
+
+	// Reset value
+	value.Store(false)
+
+	// Check 10 consecutive runs; they should be delayed for around 100 milliseconds each.
+	for range 10 {
+		iterations = 0
+		for !value.Load() {
+			iterations += 1
+			time.Sleep(10 * time.Millisecond)
+		}
+
+		// 10% difference is acceptable at this scale since time.Sleep can't be perfect and it may vary on different computers.
+		if iterations < 9 || iterations > 11 {
+			t.Errorf("Worker was not delayed for a 100 milliseconds it=%d", iterations)
+			return
+		}
+		// Reset value
+		value.Store(false)
+	}
+}
+
+func TestWorkerMgrRepeatAndDelay(t *testing.T) {
+	m := New("RepeatAndDelayTest")
+
+	value := atomic.Bool{}
+	value.Store(false)
+
+	// Create a task that should delay for 1 second and then repeat every 100 milliseconds, but with the command order reversed.
+	m.NewWorkerMgr("test", func(w *WorkerCtx) error {
+		value.Store(true)
+		return nil
+	}, nil).Repeat(100 * time.Millisecond).Delay(1 * time.Second)
+
+	iterations := 0
+	for !value.Load() {
+		iterations += 1
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// 5% difference is acceptable since time.Sleep can't be perfect and it may vary on different computers.
+	if iterations < 95 || iterations > 105 {
+		t.Errorf("WorkerMgr did not delay for a whole second it=%d", iterations)
+	}
+	// Reset value
+	value.Store(false)
+
+	// Check 10 consecutive runs; they should be delayed for around 100 milliseconds each.
+	for range 10 {
+		iterations := 0
+		for !value.Load() {
+			iterations += 1
+			time.Sleep(10 * time.Millisecond)
+		}
+
+		// 10% difference is acceptable at this scale since time.Sleep can't be perfect and it may vary on different computers.
+ if iterations < 9 || iterations > 11 { + t.Errorf("Worker was not delayed for a 100 milliseconds it=%d", iterations) + return + } + // Reset value + value.Store(false) + } +} From 3c9b636f91d97ada0d83896c93e8fd7d9987b23d Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Tue, 25 Jun 2024 18:31:01 +0300 Subject: [PATCH 16/56] [WIP] Fix more bug related to new module system --- base/api/client/message.go | 2 +- base/api/database.go | 2 +- base/api/main.go | 6 +- base/config/main.go | 4 +- base/config/module.go | 8 +- base/database/dbmodule/db.go | 4 + base/database/dbmodule/maintenance.go | 6 +- base/database/record/base.go | 2 +- base/database/record/meta-bench_test.go | 2 +- base/database/record/wrapper.go | 2 +- base/info/module/flags.go | 4 +- base/rng/entropy.go | 2 +- cmds/portmaster-core/main.go | 17 ++-- cmds/portmaster-start/logs.go | 2 +- go.mod | 53 +++++----- go.sum | 125 ++++++++++++----------- service/broadcasts/module.go | 2 +- service/compat/module.go | 10 +- service/core/base/logs.go | 2 +- service/firewall/module.go | 1 + service/instance.go | 127 ++++++++++++++---------- service/intel/customlists/module.go | 6 +- service/intel/geoip/database.go | 2 +- service/netenv/main.go | 7 +- service/netquery/module_api.go | 4 +- service/profile/config-update.go | 2 +- service/profile/fingerprint.go | 2 +- service/resolver/failing.go | 2 +- service/resolver/main.go | 10 +- service/sync/util.go | 2 +- service/ui/module.go | 5 + service/updates/main.go | 10 +- service/updates/module.go | 2 +- service/updates/restart.go | 15 +-- service/updates/upgrader.go | 4 +- spn/access/module.go | 20 ++-- spn/access/op_auth.go | 2 +- spn/access/token/pblind.go | 2 +- spn/access/token/token.go | 2 +- spn/access/zones.go | 11 +- spn/captain/module.go | 8 +- spn/captain/op_gossip.go | 2 +- spn/captain/op_gossip_query.go | 2 +- spn/captain/op_publish.go | 2 +- spn/crew/module.go | 2 +- spn/crew/op_connect.go | 2 +- spn/crew/op_ping.go | 2 +- spn/docks/bandwidth_test.go | 2 +- spn/docks/controller.go | 2 +- spn/docks/crane.go | 2 +- spn/docks/crane_establish.go | 2 +- spn/docks/crane_init.go | 2 +- spn/docks/crane_terminal.go | 2 +- spn/docks/crane_verify.go | 2 +- spn/docks/op_capacity.go | 2 +- spn/docks/op_expand.go | 2 +- spn/docks/op_latency.go | 2 +- spn/docks/op_sync_state.go | 2 +- spn/docks/op_whoami.go | 2 +- spn/docks/terminal_expansion.go | 2 +- spn/hub/update.go | 2 +- spn/navigator/module.go | 6 +- spn/patrol/module.go | 2 +- spn/terminal/init.go | 2 +- spn/terminal/module.go | 3 +- spn/terminal/msg.go | 2 +- spn/terminal/msgtypes.go | 2 +- spn/terminal/operation.go | 2 +- spn/terminal/operation_counter.go | 2 +- spn/terminal/terminal.go | 4 +- spn/terminal/terminal_test.go | 2 +- spn/terminal/testing.go | 2 +- 72 files changed, 299 insertions(+), 271 deletions(-) diff --git a/base/api/client/message.go b/base/api/client/message.go index 0927eb035..85754e230 100644 --- a/base/api/client/message.go +++ b/base/api/client/message.go @@ -6,8 +6,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/container" ) // ErrMalformedMessage is returned when a malformed message was encountered. 
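
The import moves in this commit all follow the same pattern: the serialization helpers that used to live under the monorepo's base/container path now come from the extracted safing/structures module (added to go.mod further down). A minimal sketch of code after the move; the container.New and CompileData calls are assumed to be unchanged from the previous container API:

package main

import (
	"fmt"

	// Previously imported as "github.com/safing/portmaster/base/container".
	"github.com/safing/structures/container"
)

func main() {
	// Build a container from raw data and compile it back to bytes.
	c := container.New([]byte("hello"))
	fmt.Printf("%d bytes\n", len(c.CompileData()))
}
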
diff --git a/base/api/database.go b/base/api/database.go index ee1e7eae1..295559dcd 100644 --- a/base/api/database.go +++ b/base/api/database.go @@ -12,7 +12,6 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/iterator" "github.com/safing/portmaster/base/database/query" @@ -21,6 +20,7 @@ import ( "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" + "github.com/safing/structures/container" ) const ( diff --git a/base/api/main.go b/base/api/main.go index d50c643cb..62b32ac7b 100644 --- a/base/api/main.go +++ b/base/api/main.go @@ -7,7 +7,6 @@ import ( "os" "time" - "github.com/safing/portbase/modules" "github.com/safing/portmaster/service/mgr" ) @@ -27,7 +26,8 @@ func init() { func prep() error { if exportEndpoints { - modules.SetCmdLineOperation(exportEndpointsCmd) + // FIXME(vladimir): migrate + // modules.SetCmdLineOperation(exportEndpointsCmd) } if getDefaultListenAddress() == "" { @@ -64,7 +64,7 @@ func start() error { // start api auth token cleaner if authFnSet.IsSet() { - _ = module.mgr.Repeat("clean api sessions", 5*time.Minute, cleanSessions, nil) + _ = module.mgr.Repeat("clean api sessions", 5*time.Minute, cleanSessions) } return registerEndpointBridgeDB() diff --git a/base/config/main.go b/base/config/main.go index 671e8a2b3..a1b3b19fd 100644 --- a/base/config/main.go +++ b/base/config/main.go @@ -10,7 +10,6 @@ import ( "path/filepath" "sort" - "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/base/utils/debug" @@ -47,7 +46,8 @@ func prep() error { } if exportConfig { - modules.SetCmdLineOperation(exportConfigCmd) + // FIXME(vladimir): migrate + // modules.SetCmdLineOperation(exportConfigCmd) } return registerBasicOptions() diff --git a/base/config/module.go b/base/config/module.go index 2f3031d2b..e3bc96f4f 100644 --- a/base/config/module.go +++ b/base/config/module.go @@ -20,6 +20,10 @@ type Config struct { func (u *Config) Start(m *mgr.Manager) error { u.mgr = m u.EventConfigChange = mgr.NewEventMgr[struct{}](ChangeEvent, u.mgr) + + if err := prep(); err != nil { + return err + } return start() } @@ -39,10 +43,6 @@ func New(instance instance) (*Config, error) { return nil, errors.New("only one instance allowed") } - if err := prep(); err != nil { - return nil, err - } - module = &Config{ instance: instance, } diff --git a/base/database/dbmodule/db.go b/base/database/dbmodule/db.go index b92bf83da..77af7958a 100644 --- a/base/database/dbmodule/db.go +++ b/base/database/dbmodule/db.go @@ -20,6 +20,10 @@ func (dbm *DBModule) Start(m *mgr.Manager) error { return start() } +func (dbm *DBModule) Stop(m *mgr.Manager) error { + return stop() +} + var databaseStructureRoot *utils.DirStructure // SetDatabaseLocation sets the location of the database for initialization. Supply either a path or dir structure. 
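
The Stop method added above mirrors the lifecycle the new module system expects: each module exposes Start and Stop methods that receive its *mgr.Manager. A hedged sketch of that minimal shape; the struct name and comments are illustrative, not part of the diff:

package example

import "github.com/safing/portmaster/service/mgr"

// ExampleModule shows the minimal surface a module needs under the
// mgr-based lifecycle used throughout this patch set.
type ExampleModule struct {
	mgr *mgr.Manager
}

// Start receives the module's manager; workers, events and tasks are
// registered on it here.
func (e *ExampleModule) Start(m *mgr.Manager) error {
	e.mgr = m
	return nil
}

// Stop is called on shutdown; workers started through the manager are
// expected to end when its context is canceled.
func (e *ExampleModule) Stop(m *mgr.Manager) error {
	return nil
}
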
diff --git a/base/database/dbmodule/maintenance.go b/base/database/dbmodule/maintenance.go index a899f4077..ce373fccb 100644 --- a/base/database/dbmodule/maintenance.go +++ b/base/database/dbmodule/maintenance.go @@ -9,9 +9,9 @@ import ( ) func startMaintenanceTasks() { - _ = module.mgr.Repeat("basic maintenance", 10*time.Minute, maintainBasic, nil) - _ = module.mgr.Repeat("thorough maintenance", 1*time.Hour, maintainThorough, nil) - _ = module.mgr.Repeat("record maintenance", 1*time.Hour, maintainRecords, nil) + _ = module.mgr.Repeat("basic maintenance", 10*time.Minute, maintainBasic) + _ = module.mgr.Repeat("thorough maintenance", 1*time.Hour, maintainThorough) + _ = module.mgr.Repeat("record maintenance", 1*time.Hour, maintainRecords) } func maintainBasic(ctx *mgr.WorkerCtx) error { diff --git a/base/database/record/base.go b/base/database/record/base.go index deacd78bf..347255a1f 100644 --- a/base/database/record/base.go +++ b/base/database/record/base.go @@ -3,10 +3,10 @@ package record import ( "errors" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/database/accessor" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" + "github.com/safing/structures/container" ) // TODO(ppacher): diff --git a/base/database/record/meta-bench_test.go b/base/database/record/meta-bench_test.go index f3c048542..bf1824db2 100644 --- a/base/database/record/meta-bench_test.go +++ b/base/database/record/meta-bench_test.go @@ -21,9 +21,9 @@ import ( "testing" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/structures/container" ) var testMeta = &Meta{ diff --git a/base/database/record/wrapper.go b/base/database/record/wrapper.go index b8505baff..5204ffa1a 100644 --- a/base/database/record/wrapper.go +++ b/base/database/record/wrapper.go @@ -5,10 +5,10 @@ import ( "fmt" "sync" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/database/accessor" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/structures/container" ) // Wrapper wraps raw data and implements the Record interface. 
diff --git a/base/info/module/flags.go b/base/info/module/flags.go index 2df0bbf2a..59e974ae7 100644 --- a/base/info/module/flags.go +++ b/base/info/module/flags.go @@ -6,8 +6,8 @@ import ( "fmt" "sync/atomic" - "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/mgr" ) @@ -28,7 +28,7 @@ func (i *Info) Start(m *mgr.Manager) error { } if printVersion() { - return modules.ErrCleanExit + return base.ErrCleanExit } return nil } diff --git a/base/rng/entropy.go b/base/rng/entropy.go index 7c78dc6a6..71f73845d 100644 --- a/base/rng/entropy.go +++ b/base/rng/entropy.go @@ -5,8 +5,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/service/mgr" + "github.com/safing/structures/container" ) const ( diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index 11a865c78..14c2d5379 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -2,6 +2,7 @@ package main import ( + "flag" "fmt" "runtime" @@ -22,11 +23,14 @@ import ( ) func main() { + flag.Parse() + // set information info.Set("Portmaster", "", "GPLv3") // Set default log level. log.SetLogLevel(log.WarningLevel) + log.Start() // Configure metrics. _ = metrics.SetNamespace("portmaster") @@ -37,6 +41,13 @@ func main() { // enable SPN client mode conf.EnableClient(true) + // Prep + err := base.GlobalPrep() + if err != nil { + fmt.Printf("global prep failed: %s\n", err) + return + } + // Create instance, err := service.New("2.0.0", &service.ServiceConfig{ ShutdownFunc: func(exitCode int) { @@ -47,12 +58,6 @@ func main() { fmt.Printf("error creating an instance: %s\n", err) return } - // Prep - err = base.GlobalPrep() - if err != nil { - fmt.Printf("global prep failed: %s\n", err) - return - } // Start err = instance.Group.Start() if err != nil { diff --git a/cmds/portmaster-start/logs.go b/cmds/portmaster-start/logs.go index d9b02c616..280ef8cc8 100644 --- a/cmds/portmaster-start/logs.go +++ b/cmds/portmaster-start/logs.go @@ -10,10 +10,10 @@ import ( "github.com/spf13/cobra" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/database/record" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/info" + "github.com/safing/structures/container" ) func initializeLogFile(logFilePath string, identifier string, version string) *os.File { diff --git a/go.mod b/go.mod index a3ea7e29c..ac763a24e 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,8 @@ go 1.22.0 replace github.com/tc-hib/winres => github.com/dhaavi/winres v0.2.2 require ( - fyne.io/systray v1.10.0 - github.com/VictoriaMetrics/metrics v1.33.1 + fyne.io/systray v1.11.0 + github.com/VictoriaMetrics/metrics v1.34.0 github.com/Xuanwo/go-locale v1.1.0 github.com/aead/serpent v0.0.0-20160714141033-fba169763ea6 github.com/agext/levenshtein v1.2.3 @@ -23,33 +23,34 @@ require ( github.com/florianl/go-conntrack v0.4.0 github.com/florianl/go-nfqueue v1.3.2 github.com/fogleman/gg v1.3.0 - github.com/fxamacker/cbor/v2 v2.6.0 + github.com/fxamacker/cbor/v2 v2.7.0 github.com/ghodss/yaml v1.0.0 github.com/godbus/dbus/v5 v5.1.0 github.com/gofrs/uuid v4.4.0+incompatible github.com/google/gopacket v1.1.19 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/gorilla/mux v1.8.1 - github.com/gorilla/websocket v1.5.1 + github.com/gorilla/websocket v1.5.3 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-version 
v1.7.0 github.com/jackc/puddle/v2 v2.2.1 github.com/mat/besticon v3.12.0+incompatible - github.com/miekg/dns v1.1.59 + github.com/miekg/dns v1.1.61 github.com/mitchellh/copystructure v1.2.0 github.com/mitchellh/go-server-timing v1.0.1 github.com/mr-tron/base58 v1.2.0 - github.com/oschwald/maxminddb-golang v1.12.0 + github.com/oschwald/maxminddb-golang v1.13.0 github.com/r3labs/diff/v3 v3.0.1 github.com/rot256/pblind v0.0.0-20231024115251-cd3f239f28c1 - github.com/safing/jess v0.3.3 + github.com/safing/jess v0.3.4 github.com/safing/portbase v0.19.5 github.com/safing/portmaster-android/go v0.0.0-20230830120134-3226ceac3bec + github.com/safing/structures v1.1.0 github.com/seehuhn/fortuna v1.0.1 github.com/shirou/gopsutil v3.21.11+incompatible - github.com/spf13/cobra v1.8.0 + github.com/spf13/cobra v1.8.1 github.com/spkg/zipfs v0.7.1 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 github.com/tannerryan/ring v1.1.2 github.com/tc-hib/winres v0.3.1 github.com/tevino/abool v1.2.0 @@ -59,11 +60,11 @@ require ( github.com/vincent-petithory/dataurl v1.0.0 github.com/vmihailenco/msgpack/v5 v5.4.1 go.etcd.io/bbolt v1.3.10 - golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d - golang.org/x/image v0.16.0 - golang.org/x/net v0.25.0 + golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 + golang.org/x/image v0.17.0 + golang.org/x/net v0.26.0 golang.org/x/sync v0.7.0 - golang.org/x/sys v0.20.0 + golang.org/x/sys v0.21.0 gopkg.in/yaml.v3 v3.0.1 zombiezen.com/go/sqlite v1.3.0 ) @@ -72,7 +73,7 @@ require ( github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 // indirect github.com/aead/ecdh v0.2.0 // indirect github.com/alessio/shellescape v1.4.2 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/danieljoos/wincred v1.2.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -82,15 +83,15 @@ require ( github.com/godbus/dbus v4.1.0+incompatible // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f // indirect - github.com/golang/glog v1.2.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect + github.com/golang/glog v1.2.1 // indirect + github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect @@ -112,18 +113,18 @@ require ( github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect - github.com/zalando/go-keyring v0.2.4 // indirect + github.com/zalando/go-keyring v0.2.5 // indirect github.com/zeebo/blake3 v0.2.3 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.15.0 // indirect + golang.org/x/crypto v0.24.0 // indirect + golang.org/x/mod v0.18.0 // indirect + golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.21.0 // indirect - 
google.golang.org/protobuf v1.32.0 // indirect + golang.org/x/tools v0.22.0 // indirect + google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gvisor.dev/gvisor v0.0.0-20240524212851-a244eff8ad49 // indirect - modernc.org/libc v1.50.9 // indirect + gvisor.dev/gvisor v0.0.0-20240622065613-cd3efc65190a // indirect + modernc.org/libc v1.53.3 // indirect modernc.org/mathutil v1.6.0 // indirect modernc.org/memory v1.8.0 // indirect - modernc.org/sqlite v1.29.10 // indirect + modernc.org/sqlite v1.30.1 // indirect ) diff --git a/go.sum b/go.sum index a7eb61a2b..144bd2981 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,13 @@ cloud.google.com/go v0.16.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -fyne.io/systray v1.10.0 h1:Yr1D9Lxeiw3+vSuZWPlaHC8BMjIHZXJKkek706AfYQk= -fyne.io/systray v1.10.0/go.mod h1:oM2AQqGJ1AMo4nNqZFYU8xYygSBZkW2hmdJ7n4yjedE= +fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= +fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96 h1:cTp8I5+VIoKjsnZuH8vjyaysT/ses3EvZeaV/1UkF2M= github.com/AndreasBriese/bbloom v0.0.0-20190825152654-46b345b51c96/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.4.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/VictoriaMetrics/metrics v1.33.1 h1:CNV3tfm2Kpv7Y9W3ohmvqgFWPR55tV2c7M2U6OIo+UM= -github.com/VictoriaMetrics/metrics v1.33.1/go.mod h1:r7hveu6xMdUACXvB8TYdAj8WEsKzWB0EkpJN+RDtOf8= +github.com/VictoriaMetrics/metrics v1.34.0 h1:0i8k/gdOJdSoZB4Z9pikVnVQXfhcIvnG7M7h2WaQW2w= +github.com/VictoriaMetrics/metrics v1.34.0/go.mod h1:r7hveu6xMdUACXvB8TYdAj8WEsKzWB0EkpJN+RDtOf8= github.com/Xuanwo/go-locale v1.1.0 h1:51gUxhxl66oXAjI9uPGb2O0qwPECpriKQb2hl35mQkg= github.com/Xuanwo/go-locale v1.1.0/go.mod h1:UKrHoZB3FPIk9wIG2/tVSobnHgNnceGSH3Y8DY5cASs= github.com/aead/ecdh v0.2.0 h1:pYop54xVaq/CEREFEcukHRZfTdjiWvYIsZDXXrBapQQ= @@ -30,8 +30,8 @@ github.com/brianvoe/gofakeit v3.18.0+incompatible h1:wDOmHc9DLG4nRjUVVaxA+CEglKO github.com/brianvoe/gofakeit v3.18.0+incompatible/go.mod h1:kfwdRA90vvNhPutZWfH7WPaDzUjz+CZFqG+rPkOjGOc= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cilium/ebpf v0.5.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs= github.com/cilium/ebpf v0.7.0/go.mod h1:/oI2+1shJiTGAMgl6/RgJr36Eo1jzrRcAWbcXO2usCA= github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= @@ -42,7 +42,7 @@ github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsa github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= 
-github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/danieljoos/wincred v1.2.1 h1:dl9cBrupW8+r5250DYkYxocLeZ1Y4vB1kxgtjxw8GQs= github.com/danieljoos/wincred v1.2.1/go.mod h1:uGaFL9fDn3OLTvzCGulzE+SzjEe5NGlh5FdCcyfPwps= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -76,8 +76,8 @@ github.com/fsnotify/fsnotify v1.4.3-0.20170329110642-4da3e2cfbabc/go.mod h1:jwhs github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fxamacker/cbor v1.5.1 h1:XjQWBgdmQyqimslUh5r4tUGmoqzHmBFQOImkWGi2awg= github.com/fxamacker/cbor v1.5.1/go.mod h1:3aPGItF174ni7dDzd6JZ206H8cmr4GDNBGpPa971zsU= -github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= -github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/garyburd/redigo v1.1.1-0.20170914051019-70e1b1943d4f/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -89,7 +89,6 @@ github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZs github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/godbus/dbus v4.1.0+incompatible h1:WqqLRTsQic3apZUK9qC5sGNfXthmPXzUZ7nQPrNITa4= github.com/godbus/dbus v4.1.0+incompatible/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= @@ -100,13 +99,12 @@ github.com/golang/gddo v0.0.0-20180823221919-9d8ff1c67be5/go.mod h1:xEhNfoBDX1hz github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f h1:16RtHeWGkJMc80Etb8RPCcKevXGldr57+LOyZt8zOlg= github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f/go.mod h1:ijRvpgDJDI262hYq/IQVYgf8hd8IHUs93Ol0kvMBAx4= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= -github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= +github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4= +github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/lint v0.0.0-20170918230701-e5d664eb928e/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod 
h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= @@ -133,8 +131,8 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGa github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20170920190843-316c5e0ff04e/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= @@ -167,8 +165,8 @@ github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4h github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= +github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -208,8 +206,8 @@ github.com/mdlayher/socket v0.1.0/go.mod h1:mYV5YIZAfHh4dzDVzI8x8tWLWCliuX8Mon5A github.com/mdlayher/socket v0.1.1/go.mod h1:mYV5YIZAfHh4dzDVzI8x8tWLWCliuX8Mon5Awbj+qDs= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= -github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= -github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= +github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs= +github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= @@ -225,8 +223,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nfnt/resize 
v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= -github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs= -github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY= +github.com/oschwald/maxminddb-golang v1.13.0 h1:R8xBorY71s84yO06NgTmQvqvTvlS/bnYZrrWX1MElnU= +github.com/oschwald/maxminddb-golang v1.13.0/go.mod h1:BU0z8BfFVhi1LQaonTwwGQlsHUEu9pWNdMfmq4ztm0o= github.com/pelletier/go-toml v1.0.1-0.20170904195809-1d6b12b7cb29/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -244,12 +242,14 @@ github.com/rot256/pblind v0.0.0-20231024115251-cd3f239f28c1 h1:vfAp3Jbca7Vt8axzm github.com/rot256/pblind v0.0.0-20231024115251-cd3f239f28c1/go.mod h1:2x8fbm9T+uTl919COhEVHKGkve1DnkrEnDbtGptZuW8= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/safing/jess v0.3.3 h1:0U0bWdO0sFCgox+nMOqISFrnJpVmi+VFOW1xdX6q3qw= -github.com/safing/jess v0.3.3/go.mod h1:t63qHB+4xd1HIv9MKN/qI2rc7ytvx7d6l4hbX7zxer0= +github.com/safing/jess v0.3.4 h1:/p6ensqEUn2jI/z1EB9JUdwH4MJQirh/C9jEwNBzxw8= +github.com/safing/jess v0.3.4/go.mod h1:+B6UJnXVxi406Wk08SDnoC5NNBL7t3N0vZGokEbkVQI= github.com/safing/portbase v0.19.5 h1:3/8odzlvb629tHPwdj/sthSeJcwZHYrqA6YuvNUZzNc= github.com/safing/portbase v0.19.5/go.mod h1:Qrh3ck+7VZloFmnozCs9Hj8godhJAi55cmiDiC7BwTc= github.com/safing/portmaster-android/go v0.0.0-20230830120134-3226ceac3bec h1:oSJY1seobofPwpMoJRkCgXnTwfiQWNfGMCPDfqgAEfg= github.com/safing/portmaster-android/go v0.0.0-20230830120134-3226ceac3bec/go.mod h1:abwyAQrZGemWbSh/aCD9nnkp0SvFFf/mGWkAbOwPnFE= +github.com/safing/structures v1.1.0 h1:QzHBQBjaZSLzw2f6PM4ibSmPcfBHAOB5CKJ+k4FYkhQ= +github.com/safing/structures v1.1.0/go.mod h1:QUrB74FcU41ahQ5oy3YNFCoSq+twE/n3+vNZc2K35II= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/seehuhn/fortuna v1.0.1 h1:lu9+CHsmR0bZnx5Ay646XvCSRJ8PJTi5UYJwDBX68H0= @@ -269,8 +269,8 @@ github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B github.com/spf13/cast v1.1.0/go.mod h1:r2rcYCSwa1IExKTDiTfzaxqT2FNHs8hODu4LnUfgKEg= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= -github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/jwalterweatherman v0.0.0-20170901151539-12bd96e66386/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.1-0.20170901120850-7aff26db30c1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= @@ -282,15 +282,15 @@ github.com/spf13/viper v1.3.2/go.mod 
h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DM github.com/spkg/zipfs v0.7.1 h1:+2X5lvNHTybnDMQZAIHgedRXZK1WXdc+94R/P5v2XWE= github.com/spkg/zipfs v0.7.1/go.mod h1:48LW+/Rh1G7aAav1ew1PdlYn52T+LM+ARmSHfDNJvg8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tannerryan/ring v1.1.2 h1:iXayOjqHQOLzuy9GwSKuG3nhWfzQkldMlQivcgIr7gQ= github.com/tannerryan/ring v1.1.2/go.mod h1:DkELJEjbZhJBtFKR9Xziwj3HKZnb/knRgljNqp65vH4= github.com/tevino/abool v1.2.0 h1:heAkClL8H6w+mK5md9dzsuohKeXHUpY7Vw0ZCKW+huA= @@ -330,8 +330,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= -github.com/zalando/go-keyring v0.2.4 h1:wi2xxTqdiwMKbM6TWwi+uJCG/Tum2UV0jqaQhCa9/68= -github.com/zalando/go-keyring v0.2.4/go.mod h1:HL4k+OXQfJUWaMnqyuSOc0drfGPX2b51Du6K+MRgZMk= +github.com/zalando/go-keyring v0.2.5 h1:Bc2HHpjALryKD62ppdEzaFG6VxL6Bc+5v0LYpN8Lba8= +github.com/zalando/go-keyring v0.2.5/go.mod h1:HL4k+OXQfJUWaMnqyuSOc0drfGPX2b51Du6K+MRgZMk= github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= @@ -345,19 +345,19 @@ golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d h1:N0hmiNbwsSNwHBAvR3QB5w25pUwH4tK0Y/RltD1j1h4= -golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw= -golang.org/x/image 
v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= +golang.org/x/image v0.17.0 h1:nTRVVdajgB8zCMZVsViyzhnMKPwYeroEERRC64JuLco= +golang.org/x/image v0.17.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -379,8 +379,8 @@ golang.org/x/net v0.0.0-20211020060615-d418f374d309/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220107192237-5cfca573fb4d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/oauth2 v0.0.0-20170912212905-13449ad91cb2/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20170517211232-f52d1811a629/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -398,7 +398,6 @@ golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201118182958-a01c418693c7/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -425,8 +424,8 @@ golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -434,8 +433,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.0.0-20170424234030-8be79e1e0910/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= @@ -445,8 +444,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= -golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= -golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -455,10 +454,8 @@ google.golang.org/api v0.0.0-20170921000349-586095a6e407/go.mod h1:4mhQ8q/RsB7i+ google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20170918111702-1e559d0a00ee/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/grpc v1.2.1-0.20170921194603-d4b75ebd4f9f/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= 
-google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= -google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -468,20 +465,20 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gvisor.dev/gvisor v0.0.0-20240524212851-a244eff8ad49 h1:E4ibk9lM99Nqj8fVfZQeqwDR5A1nb4GejITW7TewvMU= -gvisor.dev/gvisor v0.0.0-20240524212851-a244eff8ad49/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20240622065613-cd3efc65190a h1:nhh0326nShN7+CyCyinptYOFnk2YImbQ6XNfR7pwKC8= +gvisor.dev/gvisor v0.0.0-20240622065613-cd3efc65190a/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= honnef.co/go/tools v0.2.1/go.mod h1:lPVVZ2BS5TfnjLyizF7o7hv7j9/L+8cZY2hLyjP9cGY= honnef.co/go/tools v0.2.2/go.mod h1:lPVVZ2BS5TfnjLyizF7o7hv7j9/L+8cZY2hLyjP9cGY= -modernc.org/cc/v4 v4.21.2 h1:dycHFB/jDc3IyacKipCNSDrjIC0Lm1hyoWOZTRR20Lk= -modernc.org/cc/v4 v4.21.2/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= -modernc.org/ccgo/v4 v4.17.8 h1:yyWBf2ipA0Y9GGz/MmCmi3EFpKgeS7ICrAFes+suEbs= -modernc.org/ccgo/v4 v4.17.8/go.mod h1:buJnJ6Fn0tyAdP/dqePbrrvLyr6qslFfTbFrCuaYvtA= +modernc.org/cc/v4 v4.21.3 h1:2mhBdWKtivdFlLR1ecKXTljPG1mfvbByX7QKztAIJl8= +modernc.org/cc/v4 v4.21.3/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= +modernc.org/ccgo/v4 v4.18.1 h1:1zF5kPBFq/ZVTulBOKgQPQITdOzzyBUfC51gVYP62E4= +modernc.org/ccgo/v4 v4.18.1/go.mod h1:ao1fAxf9a2KEOL15WY8+yP3wnpaOpP/QuyFOZ9HJolM= modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw= modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= -modernc.org/libc v1.50.9 h1:hIWf1uz55lorXQhfoEoezdUHjxzuO6ceshET/yWjSjk= -modernc.org/libc v1.50.9/go.mod h1:15P6ublJ9FJR8YQCGy8DeQ2Uwur7iW9Hserr/T3OFZE= +modernc.org/libc v1.53.3 h1:9O0aSLZuHPgp49we24NoFFteRgXNLGBAQ3TODrW3XLg= +modernc.org/libc v1.53.3/go.mod h1:kb+Erju4FfHNE59xd2fNpv5CBeAeej6fHbx8p8xaiyI= modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= @@ -490,8 +487,8 @@ modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= -modernc.org/sqlite v1.29.10 h1:3u93dz83myFnMilBGCOLbr+HjklS6+5rJLx4q86RDAg= -modernc.org/sqlite v1.29.10/go.mod 
h1:ItX2a1OVGgNsFh6Dv60JQvGfJfTPHPVpV6DF59akYOA= +modernc.org/sqlite v1.30.1 h1:YFhPVfu2iIgUf9kuA1CR7iiHdcEEsI2i+yjRYHscyxk= +modernc.org/sqlite v1.30.1/go.mod h1:DUmsiWQDaAvU4abhc/N+djlom/L2o8f7gZ95RCvyoLU= modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= diff --git a/service/broadcasts/module.go b/service/broadcasts/module.go index 2ef89d948..1f6d3882b 100644 --- a/service/broadcasts/module.go +++ b/service/broadcasts/module.go @@ -55,7 +55,7 @@ func start() error { // Start broadcast notifier task. startOnce.Do(func() { - module.mgr.Repeat("broadcast notifier", 10*time.Minute, broadcastNotify, nil) + module.mgr.Repeat("broadcast notifier", 10*time.Minute, broadcastNotify) }) return nil diff --git a/service/compat/module.go b/service/compat/module.go index 1efaa690c..1f25d5271 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -17,7 +17,7 @@ type Compat struct { mgr *mgr.Manager instance instance - selfcheckTask *mgr.Task + selfcheckWorkerMgr *mgr.WorkerMgr } // Start starts the module. @@ -71,11 +71,11 @@ func start() error { startNotify() selfcheckNetworkChangedFlag.Refresh() - module.selfcheckTask = module.mgr.Repeat("compatibility self-check", 5*time.Minute, selfcheckTaskFunc, nil).Delay(selfcheckTaskRetryAfter) + module.selfcheckWorkerMgr = module.mgr.Repeat("compatibility self-check", 5*time.Minute, selfcheckTaskFunc).Delay(selfcheckTaskRetryAfter) - _ = module.mgr.Repeat("clean notify thresholds", 1*time.Hour, cleanNotifyThreshold, nil) + _ = module.mgr.Repeat("clean notify thresholds", 1*time.Hour, cleanNotifyThreshold) module.instance.NetEnv().EventNetworkChange.AddCallback("trigger compat self-check", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { - module.selfcheckTask.Delay(selfcheckTaskRetryAfter) + module.selfcheckWorkerMgr.Delay(selfcheckTaskRetryAfter) return false, nil }) return nil @@ -122,7 +122,7 @@ func selfcheckTaskFunc(wc *mgr.WorkerCtx) error { } // Retry quicker when failed. 
- module.selfcheckTask.Delay(selfcheckTaskRetryAfter) + module.selfcheckWorkerMgr.Delay(selfcheckTaskRetryAfter) return nil } diff --git a/service/core/base/logs.go b/service/core/base/logs.go index c04a048fe..8ceab9cd1 100644 --- a/service/core/base/logs.go +++ b/service/core/base/logs.go @@ -19,7 +19,7 @@ const ( ) func registerLogCleaner() { - _ = module.mgr.Delay("log cleaner", 15*time.Minute, logCleaner, nil).Repeat(24 * time.Hour) + _ = module.mgr.Delay("log cleaner", 15*time.Minute, logCleaner).Repeat(24 * time.Hour) } func logCleaner(_ *mgr.WorkerCtx) error { diff --git a/service/firewall/module.go b/service/firewall/module.go index 7439ce594..c7a9dd9af 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -46,6 +46,7 @@ func (f *Filter) Start(mgr *mgr.Manager) error { f.mgr = mgr if err := prep(); err != nil { + log.Errorf("Failed to prepare firewall module %q", err) return err } diff --git a/service/instance.go b/service/instance.go index 78eb8c79e..94183ce86 100644 --- a/service/instance.go +++ b/service/instance.go @@ -5,6 +5,7 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" "github.com/safing/portmaster/base/metrics" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/base/rng" @@ -47,39 +48,41 @@ type Instance struct { version string - api *api.API + database *dbmodule.DBModule config *config.Config + api *api.API metrics *metrics.Metrics runtime *runtime.Runtime notifications *notifications.Notifications rng *rng.Rng base *base.Base + updates *updates.Updates + geoip *geoip.GeoIP + netenv *netenv.NetEnv + access *access.Access cabin *cabin.Cabin + navigator *navigator.Navigator captain *captain.Captain crew *crew.Crew docks *docks.Docks - navigator *navigator.Navigator patrol *patrol.Patrol ships *ships.Ships sluice *sluice.SluiceModule terminal *terminal.TerminalModule - updates *updates.Updates ui *ui.UI profile *profile.ProfileModule + network *network.Network + netquery *netquery.NetQuery filter *firewall.Filter interception *interception.Interception customlist *customlists.CustomList - geoip *geoip.GeoIP - netenv *netenv.NetEnv status *status.Status broadcasts *broadcasts.Broadcasts compat *compat.Compat nameserver *nameserver.NameServer - netquery *netquery.NetQuery - network *network.Network process *process.ProcessModule resolver *resolver.ResolverModule sync *sync.Sync @@ -96,6 +99,10 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { var err error // Base modules + instance.database, err = dbmodule.New(instance) + if err != nil { + return nil, fmt.Errorf("create config module: %w", err) + } instance.config, err = config.New(instance) if err != nil { return nil, fmt.Errorf("create config module: %w", err) @@ -125,6 +132,20 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { return nil, fmt.Errorf("create base module: %w", err) } + // Global service modules + instance.updates, err = updates.New(instance, svcCfg.ShutdownFunc) + if err != nil { + return nil, fmt.Errorf("create updates module: %w", err) + } + instance.geoip, err = geoip.New(instance) + if err != nil { + return nil, fmt.Errorf("create customlist module: %w", err) + } + instance.netenv, err = netenv.New(instance) + if err != nil { + return nil, fmt.Errorf("create netenv module: %w", err) + } + // SPN modules instance.access, err = access.New(instance) if err != nil { @@ -134,6 +155,10 @@ func New(version string, 
svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create cabin module: %w", err) } + instance.navigator, err = navigator.New(instance) + if err != nil { + return nil, fmt.Errorf("create navigator module: %w", err) + } instance.captain, err = captain.New(instance, svcCfg.ShutdownFunc) if err != nil { return nil, fmt.Errorf("create captain module: %w", err) @@ -146,10 +171,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create docks module: %w", err) } - instance.navigator, err = navigator.New(instance) - if err != nil { - return nil, fmt.Errorf("create navigator module: %w", err) - } instance.patrol, err = patrol.New(instance) if err != nil { return nil, fmt.Errorf("create patrol module: %w", err) @@ -168,10 +189,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { } // Service modules - instance.updates, err = updates.New(instance, svcCfg.ShutdownFunc) - if err != nil { - return nil, fmt.Errorf("create updates module: %w", err) - } instance.ui, err = ui.New(instance) if err != nil { return nil, fmt.Errorf("create ui module: %w", err) @@ -180,6 +197,14 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create profile module: %w", err) } + instance.network, err = network.New(instance) + if err != nil { + return nil, fmt.Errorf("create network module: %w", err) + } + instance.netquery, err = netquery.NewModule(instance) + if err != nil { + return nil, fmt.Errorf("create netquery module: %w", err) + } instance.filter, err = firewall.New(instance) if err != nil { return nil, fmt.Errorf("create filter module: %w", err) @@ -192,14 +217,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create customlist module: %w", err) } - instance.geoip, err = geoip.New(instance) - if err != nil { - return nil, fmt.Errorf("create customlist module: %w", err) - } - instance.netenv, err = netenv.New(instance) - if err != nil { - return nil, fmt.Errorf("create netenv module: %w", err) - } instance.status, err = status.New(instance) if err != nil { return nil, fmt.Errorf("create status module: %w", err) @@ -216,14 +233,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create nameserver module: %w", err) } - instance.netquery, err = netquery.NewModule(instance) - if err != nil { - return nil, fmt.Errorf("create netquery module: %w", err) - } - instance.network, err = network.New(instance) - if err != nil { - return nil, fmt.Errorf("create network module: %w", err) - } instance.process, err = process.New(instance) if err != nil { return nil, fmt.Errorf("create process module: %w", err) @@ -243,6 +252,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { // Add all modules to instance group. 
instance.Group = mgr.NewGroup( + instance.database, instance.config, instance.api, instance.metrics, @@ -251,31 +261,32 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.rng, instance.base, + instance.updates, + instance.geoip, + instance.netenv, + instance.access, instance.cabin, + instance.navigator, instance.captain, instance.crew, instance.docks, - instance.navigator, instance.patrol, instance.ships, instance.sluice, instance.terminal, - instance.updates, instance.ui, instance.profile, + instance.network, + instance.netquery, instance.filter, instance.interception, instance.customlist, - instance.geoip, - instance.netenv, instance.status, instance.broadcasts, instance.compat, instance.nameserver, - instance.netquery, - instance.network, instance.process, instance.resolver, instance.sync, @@ -299,6 +310,16 @@ func (i *Instance) Version() string { return i.version } +// Database returns the database module. +func (i *Instance) Database() *dbmodule.DBModule { + return i.database +} + +// Config returns the config module. +func (i *Instance) Config() *config.Config { + return i.config +} + // API returns the api module. func (i *Instance) API() *api.API { return i.api @@ -329,6 +350,21 @@ func (i *Instance) Base() *base.Base { return i.base } +// Updates returns the updates module. +func (i *Instance) Updates() *updates.Updates { + return i.updates +} + +// GeoIP returns the geoip module. +func (i *Instance) GeoIP() *geoip.GeoIP { + return i.geoip +} + +// NetEnv returns the netenv module. +func (i *Instance) NetEnv() *netenv.NetEnv { + return i.netenv +} + // Access returns the access module. func (i *Instance) Access() *access.Access { return i.access @@ -379,21 +415,11 @@ func (i *Instance) Terminal() *terminal.TerminalModule { return i.terminal } -// Updates returns the updates module. -func (i *Instance) Updates() *updates.Updates { - return i.updates -} - // UI returns the ui module. func (i *Instance) UI() *ui.UI { return i.ui } -// Config returns the config module. -func (i *Instance) Config() *config.Config { - return i.config -} - // Profile returns the profile module. func (i *Instance) Profile() *profile.ProfileModule { return i.profile @@ -414,11 +440,6 @@ func (i *Instance) CustomList() *customlists.CustomList { return i.customlist } -// NetEnv returns the netenv module. -func (i *Instance) NetEnv() *netenv.NetEnv { - return i.netenv -} - // Status returns the status module. func (i *Instance) Status() *status.Status { return i.status @@ -439,7 +460,7 @@ func (i *Instance) NameServer() *nameserver.NameServer { return i.nameserver } -// NetQuery returns the newquery module. +// NetQuery returns the netquery module. 
func (i *Instance) NetQuery() *netquery.NetQuery { return i.netquery } diff --git a/service/intel/customlists/module.go b/service/intel/customlists/module.go index a21250a44..7401ed810 100644 --- a/service/intel/customlists/module.go +++ b/service/intel/customlists/module.go @@ -21,7 +21,7 @@ type CustomList struct { mgr *mgr.Manager instance instance - updateFilterListTask *mgr.Task + updateFilterListWorkerMgr *mgr.WorkerMgr States *mgr.StateMgr } @@ -30,7 +30,7 @@ func (cl *CustomList) Start(m *mgr.Manager) error { cl.mgr = m cl.States = mgr.NewStateMgr(m) - cl.updateFilterListTask = m.NewTask("update custom filter list", checkAndUpdateFilterList, nil) + cl.updateFilterListWorkerMgr = m.NewWorkerMgr("update custom filter list", checkAndUpdateFilterList, nil) if err := prep(); err != nil { return err @@ -101,7 +101,7 @@ func start() error { ) // Create parser task and enqueue for execution. "checkAndUpdateFilterList" will schedule the next execution. - module.updateFilterListTask.Delay(20 * time.Second).Repeat(1 * time.Minute) + module.updateFilterListWorkerMgr.Delay(20 * time.Second).Repeat(1 * time.Minute) return nil } diff --git a/service/intel/geoip/database.go b/service/intel/geoip/database.go index 1bffed787..6aee3d944 100644 --- a/service/intel/geoip/database.go +++ b/service/intel/geoip/database.go @@ -148,7 +148,7 @@ func (upd *updateWorker) triggerUpdate() { func (upd *updateWorker) start() { upd.once.Do(func() { - module.mgr.Delay("geoip-updater", time.Second*10, upd.run, nil) + module.mgr.Delay("geoip-updater", time.Second*10, upd.run) }) } diff --git a/service/netenv/main.go b/service/netenv/main.go index 35b231278..f6dcef9d7 100644 --- a/service/netenv/main.go +++ b/service/netenv/main.go @@ -25,6 +25,9 @@ type NetEnv struct { } func (ne *NetEnv) Start(m *mgr.Manager) error { + ne.EventNetworkChange = mgr.NewEventMgr[struct{}]("network change", m) + ne.EventOnlineStatusChange = mgr.NewEventMgr[OnlineStatus]("online status change", m) + if err := prep(); err != nil { return err } @@ -90,10 +93,6 @@ func New(instance instance) (*NetEnv, error) { return nil, errors.New("only one instance allowed") } - if err := prep(); err != nil { - return nil, err - } - module = &NetEnv{ instance: instance, } diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go index 9f1fb6ca9..c0d6a0a1a 100644 --- a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -258,9 +258,9 @@ func (nq *NetQuery) Start(m *mgr.Manager) error { } }) - nq.mgr.Delay("network history cleaner", 10*time.Minute, func(w *mgr.WorkerCtx) error { + nq.mgr.Delay("network history cleaner delay", 10*time.Minute, func(w *mgr.WorkerCtx) error { return nq.Store.CleanupHistory(w.Ctx()) - }, nil).Repeat(1 * time.Hour) + }).Repeat(1 * time.Hour) // For debugging, provide a simple direct SQL query interface using // the runtime database. diff --git a/service/profile/config-update.go b/service/profile/config-update.go index 1c7382413..94c64cf0d 100644 --- a/service/profile/config-update.go +++ b/service/profile/config-update.go @@ -137,7 +137,7 @@ func updateGlobalConfigProfile(_ context.Context) error { _ = module.mgr.Delay("retry updating global config profile", 15*time.Second, func(w *mgr.WorkerCtx) error { return updateGlobalConfigProfile(w.Ctx()) - }, nil) + }) // Add module warning to inform user. 
module.States.Add(mgr.State{ diff --git a/service/profile/fingerprint.go b/service/profile/fingerprint.go index 2185ac33e..fc1e11a33 100644 --- a/service/profile/fingerprint.go +++ b/service/profile/fingerprint.go @@ -8,7 +8,7 @@ import ( "golang.org/x/exp/slices" "github.com/safing/jess/lhash" - "github.com/safing/portmaster/base/container" + "github.com/safing/structures/container" ) // # Matching and Scores diff --git a/service/resolver/failing.go b/service/resolver/failing.go index e950f1e31..bcb04a45a 100644 --- a/service/resolver/failing.go +++ b/service/resolver/failing.go @@ -74,7 +74,7 @@ func checkFailingResolvers(wc *mgr.WorkerCtx) error { var resolvers []*Resolver // Set next execution time. - module.failingResolverTask.Delay(time.Duration(nameserverRetryRate()) * time.Second) + module.failingResolverWorkerMgr.Delay(time.Duration(nameserverRetryRate()) * time.Second) // Make a copy of the resolver list. func() { diff --git a/service/resolver/main.go b/service/resolver/main.go index 5b77be84d..a15486111 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -25,8 +25,8 @@ type ResolverModule struct { mgr *mgr.Manager instance instance - failingResolverTask *mgr.Task - suggestUsingStaleCacheTask *mgr.Task + failingResolverWorkerMgr *mgr.WorkerMgr + suggestUsingStaleCacheTask *mgr.WorkerMgr States *mgr.StateMgr } @@ -102,8 +102,8 @@ func start() error { }) // Check failing resolvers regularly and when the network changes. - module.failingResolverTask = module.mgr.NewTask("check failing resolvers", checkFailingResolvers, nil) - module.failingResolverTask.Go() + module.failingResolverWorkerMgr = module.mgr.NewWorkerMgr("check failing resolvers", checkFailingResolvers, nil) + module.failingResolverWorkerMgr.Go() module.instance.NetEnv().EventNetworkChange.AddCallback( "check failing resolvers", func(wc *mgr.WorkerCtx, _ struct{}) (bool, error) { @@ -111,7 +111,7 @@ func start() error { return false, nil }) - module.suggestUsingStaleCacheTask = module.mgr.NewTask("suggest using stale cache", suggestUsingStaleCacheTask, nil) + module.suggestUsingStaleCacheTask = module.mgr.NewWorkerMgr("suggest using stale cache", suggestUsingStaleCacheTask, nil) module.suggestUsingStaleCacheTask.Go() module.mgr.Go( diff --git a/service/sync/util.go b/service/sync/util.go index bbd09e350..7fd3fb0c2 100644 --- a/service/sync/util.go +++ b/service/sync/util.go @@ -10,8 +10,8 @@ import ( "github.com/safing/jess/filesig" "github.com/safing/portmaster/base/api" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/container" ) // Type is the type of an export. diff --git a/service/ui/module.go b/service/ui/module.go index 54ef93838..9c74c4d17 100644 --- a/service/ui/module.go +++ b/service/ui/module.go @@ -45,6 +45,11 @@ type UI struct { // Start starts the module. 
func (ui *UI) Start(m *mgr.Manager) error { ui.mgr = m + + if err := prep(); err != nil { + return err + } + return start() } diff --git a/service/updates/main.go b/service/updates/main.go index 55a47d5c2..d445f9e46 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -111,7 +111,7 @@ func prep() error { func start() error { initConfig() - _ = module.mgr.Repeat("automatic restart", 10*time.Minute, automaticRestart, nil) + _ = module.mgr.Repeat("automatic restart", 10*time.Minute, automaticRestart) module.instance.Config().EventConfigChange.AddCallback("update registry config", updateRegistryConfig) @@ -190,14 +190,14 @@ func start() error { } // start updater task - module.updateTask = module.mgr.NewTask("updater", checkForUpdates, nil) + module.updateWorkerMgr = module.mgr.NewWorkerMgr("updater", checkForUpdates, nil) if !disableTaskSchedule { - _ = module.updateTask.Repeat(30 * time.Minute) + _ = module.updateWorkerMgr.Repeat(30 * time.Minute) } if updateASAP { - module.updateTask.Go() + module.updateWorkerMgr.Go() } // react to upgrades @@ -225,7 +225,7 @@ func TriggerUpdate(forceIndexCheck, downloadAll bool) error { } // If index check if forced, start quicker. - module.updateTask.Go() + module.updateWorkerMgr.Go() } log.Debugf("updates: triggering update to run as soon as possible") diff --git a/service/updates/module.go b/service/updates/module.go index ef5858eba..6372b83d5 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -16,7 +16,7 @@ type Updates struct { instance instance shutdownFunc func(exitCode int) - updateTask *mgr.Task + updateWorkerMgr *mgr.WorkerMgr EventResourcesUpdated *mgr.EventMgr[struct{}] EventVersionsUpdated *mgr.EventMgr[struct{}] diff --git a/service/updates/restart.go b/service/updates/restart.go index f219b1ef1..64129a685 100644 --- a/service/updates/restart.go +++ b/service/updates/restart.go @@ -8,7 +8,6 @@ import ( "github.com/tevino/abool" - "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" ) @@ -23,7 +22,7 @@ var ( // should be restarted automatically when triggering a restart internally. RebootOnRestart bool - restartTask *modules.Task + restartWorkerMgr *mgr.WorkerMgr restartPending = abool.New() restartTriggered = abool.New() @@ -61,7 +60,8 @@ func DelayedRestart(delay time.Duration) { // Schedule the restart task. log.Warningf("updates: restart triggered, will execute in %s", delay) restartAt := time.Now().Add(delay) - restartTask.Schedule(restartAt) + // FIXME(vladimir): provide restart task + // restartTask.Schedule(restartAt) // Set restartTime. restartTimeLock.Lock() @@ -75,7 +75,8 @@ func AbortRestart() { log.Warningf("updates: restart aborted") // Cancel schedule. - restartTask.Schedule(time.Time{}) + // FIXME(vladimir): provide restart task + // restartTask.Schedule(time.Time{}) } } @@ -83,7 +84,8 @@ func AbortRestart() { // This can be used to prepone a scheduled restart if the conditions are preferable. func TriggerRestartIfPending() { if restartPending.IsSet() { - restartTask.StartASAP() + // FIXME(vladimir): provide restart task + // restartTask.StartASAP() } } @@ -91,7 +93,8 @@ func TriggerRestartIfPending() { // This only works if the process is managed by portmaster-start. 
func RestartNow() { restartPending.Set() - restartTask.StartASAP() + // FIXME(vladimir): provide restart task + // restartTask.StartASAP() } func automaticRestart(w *mgr.WorkerCtx) error { diff --git a/service/updates/upgrader.go b/service/updates/upgrader.go index 093647d65..622b3909b 100644 --- a/service/updates/upgrader.go +++ b/service/updates/upgrader.go @@ -182,14 +182,14 @@ func upgradeHub() error { // Increase update checks in order to detect aborts better. if !disableTaskSchedule { - module.updateTask.Repeat(10 * time.Minute) + module.updateWorkerMgr.Repeat(10 * time.Minute) } } else { AbortRestart() // Set update task schedule back to normal. if !disableTaskSchedule { - module.updateTask.Repeat(updateTaskRepeatDuration) + module.updateWorkerMgr.Repeat(updateTaskRepeatDuration) } } diff --git a/spn/access/module.go b/spn/access/module.go index 5db5a44c1..95ab8bcce 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -20,7 +20,7 @@ type Access struct { mgr *mgr.Manager instance instance - updateAccountTask *mgr.Task + updateAccountWorkerMgr *mgr.WorkerMgr EventAccountUpdate *mgr.EventMgr[struct{}] } @@ -28,7 +28,7 @@ type Access struct { func (a *Access) Start(m *mgr.Manager) error { a.mgr = m a.EventAccountUpdate = mgr.NewEventMgr[struct{}](AccountUpdateEvent, m) - a.updateAccountTask = m.NewTask("update account", UpdateAccount, nil) + a.updateAccountWorkerMgr = m.NewWorkerMgr("update account", UpdateAccount, nil) if err := prep(); err != nil { return err @@ -87,7 +87,7 @@ func start() error { loadTokens() // Register new task. - module.updateAccountTask.Delay(1 * time.Minute) + module.updateAccountWorkerMgr.Delay(1 * time.Minute) } return nil @@ -108,12 +108,12 @@ func stop() error { // UpdateAccount updates the user account and fetches new tokens, if needed. func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error { // Schedule next call this will change if other conditions are met bellow. - module.updateAccountTask.Delay(24 * time.Hour) + module.updateAccountWorkerMgr.Delay(24 * time.Hour) // Retry sooner if the token issuer is failing. defer func() { if tokenIssuerIsFailing.IsSet() { - module.updateAccountTask.Delay(tokenIssuerRetryDuration) + module.updateAccountWorkerMgr.Delay(tokenIssuerRetryDuration) } }() @@ -145,14 +145,14 @@ func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error { case time.Until(*u.Subscription.EndsAt) < 24*time.Hour && time.Since(*u.Subscription.EndsAt) < 24*time.Hour: // Update account every hour for 24h hours before and after the subscription ends. - module.updateAccountTask.Delay(1 * time.Hour) + module.updateAccountWorkerMgr.Delay(1 * time.Hour) case u.Subscription.NextBillingDate == nil: // No auto-subscription. case time.Until(*u.Subscription.NextBillingDate) < 24*time.Hour && time.Since(*u.Subscription.NextBillingDate) < 24*time.Hour: // Update account every hour 24h hours before and after the next billing date. - module.updateAccountTask.Delay(1 * time.Hour) + module.updateAccountWorkerMgr.Delay(1 * time.Hour) } return nil @@ -186,7 +186,7 @@ func tokenIssuerFailed() { // return // } - module.updateAccountTask.Delay(tokenIssuerRetryDuration) + module.updateAccountWorkerMgr.Delay(tokenIssuerRetryDuration) } // IsLoggedIn returns whether a User is currently logged in. 
@@ -216,10 +216,6 @@ func New(instance instance) (*Access, error) { return nil, errors.New("only one instance allowed") } - if err := prep(); err != nil { - return nil, err - } - module = &Access{ instance: instance, } diff --git a/spn/access/op_auth.go b/spn/access/op_auth.go index 9ce642743..f31986f78 100644 --- a/spn/access/op_auth.go +++ b/spn/access/op_auth.go @@ -3,10 +3,10 @@ package access import ( "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/access/token" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // OpTypeAccessCodeAuth is the type ID of the auth operation. diff --git a/spn/access/token/pblind.go b/spn/access/token/pblind.go index 10e9bbfd2..97aa18922 100644 --- a/spn/access/token/pblind.go +++ b/spn/access/token/pblind.go @@ -13,8 +13,8 @@ import ( "github.com/mr-tron/base58" "github.com/rot256/pblind" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/container" ) const pblindSecretSize = 32 diff --git a/spn/access/token/token.go b/spn/access/token/token.go index 9b615b1c4..a3d9c56bf 100644 --- a/spn/access/token/token.go +++ b/spn/access/token/token.go @@ -8,7 +8,7 @@ import ( "github.com/mr-tron/base58" - "github.com/safing/portmaster/base/container" + "github.com/safing/structures/container" ) // Token represents a token, consisting of a zone (name) and some data. diff --git a/spn/access/zones.go b/spn/access/zones.go index 585756735..0e550785f 100644 --- a/spn/access/zones.go +++ b/spn/access/zones.go @@ -139,15 +139,8 @@ func initializeTestZone() error { } func shouldRequestTokensHandler(_ token.Handler) { - // accountUpdateTask is always set in client mode and when the module is online. - // Check if it's set in case this gets executed in other circumstances. - // if accountUpdateTask == nil { - // log.Warningf("spn/access: trying to trigger account update, but the task is not available") - // return - // } - - // accountUpdateTask.StartASAP() - module.mgr.Go("update account", UpdateAccount) + // Run the account update task as now. + module.updateAccountWorkerMgr.Go() } // GetTokenAmount returns the amount of tokens for the given zones. 
diff --git a/spn/captain/module.go b/spn/captain/module.go index f71871f2b..4c82c1fc2 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -36,7 +36,7 @@ type Captain struct { shutdownFunc func(exitCode int) healthCheckTicker *mgr.SleepyTicker - maintainPublicStatus *mgr.Task + maintainPublicStatus *mgr.WorkerMgr States *mgr.StateMgr EventSPNConnected *mgr.EventMgr[struct{}] @@ -44,8 +44,10 @@ type Captain struct { func (c *Captain) Start(m *mgr.Manager) error { c.mgr = m + c.States = mgr.NewStateMgr(m) c.EventSPNConnected = mgr.NewEventMgr[struct{}](SPNConnectedEvent, m) - c.maintainPublicStatus = m.NewTask("maintain public status", maintainPublicStatus, nil) + c.maintainPublicStatus = m.NewWorkerMgr("maintain public status", maintainPublicStatus, nil) + if err := prep(); err != nil { return err } @@ -179,7 +181,7 @@ func start() error { // network optimizer if conf.PublicHub() { - module.mgr.Delay("optimize network", 15*time.Second, optimizeNetwork, nil).Repeat(1 * time.Minute) + module.mgr.Delay("optimize network delay", 15*time.Second, optimizeNetwork).Repeat(1 * time.Minute) } // client + home hub manager diff --git a/spn/captain/op_gossip.go b/spn/captain/op_gossip.go index 4e6866044..de5edaa4d 100644 --- a/spn/captain/op_gossip.go +++ b/spn/captain/op_gossip.go @@ -3,13 +3,13 @@ package captain import ( "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // GossipOpType is the type ID of the gossip operation. diff --git a/spn/captain/op_gossip_query.go b/spn/captain/op_gossip_query.go index bc7e6b7e7..27d605a8d 100644 --- a/spn/captain/op_gossip_query.go +++ b/spn/captain/op_gossip_query.go @@ -5,7 +5,6 @@ import ( "strings" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" @@ -13,6 +12,7 @@ import ( "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // GossipQueryOpType is the type ID of the gossip query operation. diff --git a/spn/captain/op_publish.go b/spn/captain/op_publish.go index c1fd29e70..5a16ddc32 100644 --- a/spn/captain/op_publish.go +++ b/spn/captain/op_publish.go @@ -3,12 +3,12 @@ package captain import ( "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // PublishOpType is the type ID of the publish operation. 
diff --git a/spn/crew/module.go b/spn/crew/module.go index 3523a10b6..54ab7051b 100644 --- a/spn/crew/module.go +++ b/spn/crew/module.go @@ -24,7 +24,7 @@ func (c *Crew) Stop(m *mgr.Manager) error { } func start() error { - _ = module.mgr.Repeat("sticky cleaner", 10*time.Minute, cleanStickyHubs, nil) + _ = module.mgr.Repeat("sticky cleaner", 10*time.Minute, cleanStickyHubs) return registerMetrics() } diff --git a/spn/crew/op_connect.go b/spn/crew/op_connect.go index 82079d9f3..394c4fc57 100644 --- a/spn/crew/op_connect.go +++ b/spn/crew/op_connect.go @@ -10,7 +10,6 @@ import ( "sync/atomic" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" @@ -18,6 +17,7 @@ import ( "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // ConnectOpType is the type ID for the connection operation. diff --git a/spn/crew/op_ping.go b/spn/crew/op_ping.go index 2976fd611..eb8b240f4 100644 --- a/spn/crew/op_ping.go +++ b/spn/crew/op_ping.go @@ -4,10 +4,10 @@ import ( "crypto/subtle" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) const ( diff --git a/spn/docks/bandwidth_test.go b/spn/docks/bandwidth_test.go index c526fa9e9..1924be694 100644 --- a/spn/docks/bandwidth_test.go +++ b/spn/docks/bandwidth_test.go @@ -6,9 +6,9 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) func TestEffectiveBandwidth(t *testing.T) { //nolint:paralleltest // Run alone. diff --git a/spn/docks/controller.go b/spn/docks/controller.go index 9a1beb562..6fc46b33f 100644 --- a/spn/docks/controller.go +++ b/spn/docks/controller.go @@ -1,8 +1,8 @@ package docks import ( - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // CraneControllerTerminal is a terminal for the crane itself. 
diff --git a/spn/docks/crane.go b/spn/docks/crane.go index 65e27147e..4d2a84bda 100644 --- a/spn/docks/crane.go +++ b/spn/docks/crane.go @@ -12,7 +12,6 @@ import ( "github.com/tevino/abool" "github.com/safing/jess" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" @@ -21,6 +20,7 @@ import ( "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/ships" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) const ( diff --git a/spn/docks/crane_establish.go b/spn/docks/crane_establish.go index d03896fe8..4759f445a 100644 --- a/spn/docks/crane_establish.go +++ b/spn/docks/crane_establish.go @@ -3,10 +3,10 @@ package docks import ( "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) const ( diff --git a/spn/docks/crane_init.go b/spn/docks/crane_init.go index 807414119..9bf08773d 100644 --- a/spn/docks/crane_init.go +++ b/spn/docks/crane_init.go @@ -5,13 +5,13 @@ import ( "time" "github.com/safing/jess" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/info" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) /* diff --git a/spn/docks/crane_terminal.go b/spn/docks/crane_terminal.go index 5a6d7a53b..4ed392a97 100644 --- a/spn/docks/crane_terminal.go +++ b/spn/docks/crane_terminal.go @@ -3,9 +3,9 @@ package docks import ( "net" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // CraneTerminal is a terminal started by a crane. diff --git a/spn/docks/crane_verify.go b/spn/docks/crane_verify.go index cb2c86620..f6f976a7a 100644 --- a/spn/docks/crane_verify.go +++ b/spn/docks/crane_verify.go @@ -6,10 +6,10 @@ import ( "fmt" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) const ( diff --git a/spn/docks/op_capacity.go b/spn/docks/op_capacity.go index 9c924d97d..e26aec352 100644 --- a/spn/docks/op_capacity.go +++ b/spn/docks/op_capacity.go @@ -7,11 +7,11 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) const ( diff --git a/spn/docks/op_expand.go b/spn/docks/op_expand.go index f4c69a747..1c01ad544 100644 --- a/spn/docks/op_expand.go +++ b/spn/docks/op_expand.go @@ -8,10 +8,10 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // ExpandOpType is the type ID of the expand operation. 
diff --git a/spn/docks/op_latency.go b/spn/docks/op_latency.go index 7e2c19339..59681fc12 100644 --- a/spn/docks/op_latency.go +++ b/spn/docks/op_latency.go @@ -5,12 +5,12 @@ import ( "fmt" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) const ( diff --git a/spn/docks/op_sync_state.go b/spn/docks/op_sync_state.go index 72bb04f3e..c5303544e 100644 --- a/spn/docks/op_sync_state.go +++ b/spn/docks/op_sync_state.go @@ -4,11 +4,11 @@ import ( "context" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // SyncStateOpType is the type ID of the sync state operation. diff --git a/spn/docks/op_whoami.go b/spn/docks/op_whoami.go index 9f6ce8606..53ca914d6 100644 --- a/spn/docks/op_whoami.go +++ b/spn/docks/op_whoami.go @@ -3,9 +3,9 @@ package docks import ( "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) const ( diff --git a/spn/docks/terminal_expansion.go b/spn/docks/terminal_expansion.go index a04f93cb3..31ba0480a 100644 --- a/spn/docks/terminal_expansion.go +++ b/spn/docks/terminal_expansion.go @@ -7,9 +7,9 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" + "github.com/safing/structures/container" ) // ExpansionTerminal is used for expanding to another Hub. diff --git a/spn/hub/update.go b/spn/hub/update.go index 98ca680c3..0d6a9efd9 100644 --- a/spn/hub/update.go +++ b/spn/hub/update.go @@ -7,11 +7,11 @@ import ( "github.com/safing/jess" "github.com/safing/jess/lhash" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/netutils" + "github.com/safing/structures/container" ) var ( diff --git a/spn/navigator/module.go b/spn/navigator/module.go index 75c44bcb9..41a2cc703 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -110,12 +110,12 @@ geoInitCheck: } // TODO: delete superseded hubs after x amount of time - _ = module.mgr.Delay("update states", 3*time.Minute, Main.updateStates, nil).Repeat(1 * time.Hour) - _ = module.mgr.Delay("update failing states delay", 3*time.Minute, Main.updateFailingStates, nil).Repeat(1 * time.Minute) + _ = module.mgr.Delay("update states", 3*time.Minute, Main.updateStates).Repeat(1 * time.Hour) + _ = module.mgr.Delay("update failing states", 3*time.Minute, Main.updateFailingStates).Repeat(1 * time.Minute) if conf.PublicHub() { // Only measure Hubs on public Hubs. - module.mgr.Delay("measure hubs delay", 5*time.Minute, Main.measureHubs, nil).Repeat(1 * time.Minute) + module.mgr.Delay("measure hubs", 5*time.Minute, Main.measureHubs).Repeat(1 * time.Minute) // Only register metrics on Hubs, as they only make sense there. 
err := registerMetrics() diff --git a/spn/patrol/module.go b/spn/patrol/module.go index 78211a9be..414c06165 100644 --- a/spn/patrol/module.go +++ b/spn/patrol/module.go @@ -24,7 +24,7 @@ func (p *Patrol) Start(m *mgr.Manager) error { p.EventChangeSignal = mgr.NewEventMgr[struct{}](ChangeSignalEventName, m) if conf.PublicHub() { - m.Repeat("connectivity test", 5*time.Minute, connectivityCheckTask, nil) + m.Repeat("connectivity test", 5*time.Minute, connectivityCheckTask) } return nil } diff --git a/spn/terminal/init.go b/spn/terminal/init.go index 3c6ce921a..a437e9f5f 100644 --- a/spn/terminal/init.go +++ b/spn/terminal/init.go @@ -4,11 +4,11 @@ import ( "context" "github.com/safing/jess" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" + "github.com/safing/structures/container" ) /* diff --git a/spn/terminal/module.go b/spn/terminal/module.go index 3caeaf82a..d032ccd03 100644 --- a/spn/terminal/module.go +++ b/spn/terminal/module.go @@ -18,6 +18,7 @@ type TerminalModule struct { } func (s *TerminalModule) Start(m *mgr.Manager) error { + s.mgr = m return start() } @@ -26,7 +27,7 @@ func (s *TerminalModule) Stop(m *mgr.Manager) error { } var ( - rngFeeder *rng.Feeder = rng.NewFeeder() + rngFeeder *rng.Feeder = nil scheduler *unit.Scheduler diff --git a/spn/terminal/msg.go b/spn/terminal/msg.go index 764601e87..1a2754996 100644 --- a/spn/terminal/msg.go +++ b/spn/terminal/msg.go @@ -4,8 +4,8 @@ import ( "fmt" "runtime" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/unit" + "github.com/safing/structures/container" ) // Msg is a message within the SPN network stack. diff --git a/spn/terminal/msgtypes.go b/spn/terminal/msgtypes.go index a7d244b34..fba9d3235 100644 --- a/spn/terminal/msgtypes.go +++ b/spn/terminal/msgtypes.go @@ -1,8 +1,8 @@ package terminal import ( - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/structures/container" ) /* diff --git a/spn/terminal/operation.go b/spn/terminal/operation.go index 2f58ce639..7e014e0ec 100644 --- a/spn/terminal/operation.go +++ b/spn/terminal/operation.go @@ -7,10 +7,10 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/service/mgr" + "github.com/safing/structures/container" ) // Operation is an interface for all operations. diff --git a/spn/terminal/operation_counter.go b/spn/terminal/operation_counter.go index 1609a9455..687ade535 100644 --- a/spn/terminal/operation_counter.go +++ b/spn/terminal/operation_counter.go @@ -5,11 +5,11 @@ import ( "sync" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" + "github.com/safing/structures/container" ) // CounterOpType is the type ID for the Counter Operation. 
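Beyond the import swap, the navigator and patrol hunks above also change the scheduling helpers: the trailing options argument is gone, the worker function is now the last parameter of Delay and Repeat, and further configuration happens by chaining on the returned worker manager. The resulting call shape, sketched with updateStates as an illustrative stand-in worker:

package example

import (
	"time"

	"github.com/safing/portmaster/service/mgr"
)

// updateStates is a stand-in worker; workers take a worker context and return an error.
func updateStates(wc *mgr.WorkerCtx) error {
	// ... recompute states ...
	return nil
}

func scheduleWorkers(m *mgr.Manager) {
	// First run after three minutes, then repeated hourly via the chained call.
	_ = m.Delay("update states", 3*time.Minute, updateStates).Repeat(1 * time.Hour)

	// Plain periodic worker with no initial delay.
	m.Repeat("connectivity test", 5*time.Minute, updateStates)
}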
diff --git a/spn/terminal/terminal.go b/spn/terminal/terminal.go index e4b044f9e..1288303b3 100644 --- a/spn/terminal/terminal.go +++ b/spn/terminal/terminal.go @@ -9,12 +9,12 @@ import ( "github.com/tevino/abool" "github.com/safing/jess" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/base/modules" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" + "github.com/safing/structures/container" ) const ( diff --git a/spn/terminal/terminal_test.go b/spn/terminal/terminal_test.go index 7d0a83434..1aeee14b3 100644 --- a/spn/terminal/terminal_test.go +++ b/spn/terminal/terminal_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" + "github.com/safing/structures/container" ) func TestTerminals(t *testing.T) { diff --git a/spn/terminal/testing.go b/spn/terminal/testing.go index eca8e9fd8..67c7dd8d9 100644 --- a/spn/terminal/testing.go +++ b/spn/terminal/testing.go @@ -4,10 +4,10 @@ import ( "context" "time" - "github.com/safing/portmaster/base/container" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" + "github.com/safing/structures/container" ) const ( From 4d9b908f43f935c5996ac4a756bf9e0040fe4ecf Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Thu, 27 Jun 2024 12:28:44 +0300 Subject: [PATCH 17/56] [WIP] Fix start handing of the new module system --- cmds/portmaster-core/main.go | 77 ++++++++++++++++++++++++++++++++++-- service/firewall/module.go | 6 ++- service/instance.go | 12 +++--- service/mgr/module.go | 60 ++++++++++++++++++++++++---- service/mgr/states.go | 12 ++++-- service/mgr/worker.go | 2 +- service/mgr/workermgr.go | 10 ++++- service/profile/module.go | 3 +- service/updates/main.go | 8 +--- service/updates/module.go | 4 +- service/updates/restart.go | 13 ++---- 11 files changed, 163 insertions(+), 44 deletions(-) diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index 14c2d5379..b200a7d44 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -22,6 +22,8 @@ import ( _ "github.com/safing/portmaster/spn/captain" ) +var sigUSR1 = syscall.Signal(0xa) + func main() { flag.Parse() @@ -59,9 +61,78 @@ func main() { return } // Start - err = instance.Group.Start() + go func() { + err = instance.Group.Start() + if err != nil { + fmt.Printf("instance start failed: %s\n", err) + return + } + }() + + // Wait for signal. + signalCh := make(chan os.Signal, 1) + signal.Notify( + signalCh, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + sigUSR1, + ) + +signalLoop: + for { + select { + case sig := <-signalCh: + // Only print and continue to wait if SIGUSR1 + if sig == sigUSR1 { + printStackTo(os.Stderr, "PRINTING STACK ON REQUEST") + continue signalLoop + } + + fmt.Println(" ") // CLI output. 
+ slog.Warn("program was interrupted, stopping") + + // catch signals during shutdown + go func() { + forceCnt := 5 + for { + <-signalCh + forceCnt-- + if forceCnt > 0 { + fmt.Printf(" again, but already shutting down - %d more to force\n", forceCnt) + } else { + printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") + os.Exit(1) + } + } + }() + + go func() { + time.Sleep(3 * time.Minute) + printStackTo(os.Stderr, "PRINTING STACK - TAKING TOO LONG FOR SHUTDOWN") + os.Exit(1) + }() + + if err := instance.Stop(); err != nil { + slog.Error("failed to stop portmaster", "err", err) + continue signalLoop + } + break signalLoop + + case <-instance.Done(): + break signalLoop + } + } +} + +func printStackTo(writer io.Writer, msg string) { + _, err := fmt.Fprintf(writer, "===== %s =====\n", msg) + if err == nil { + err = pprof.Lookup("goroutine").WriteTo(writer, 1) + } if err != nil { - fmt.Printf("instance start failed: %s\n", err) - return + slog.Error("failed to write stack trace", "err", err) } } diff --git a/service/firewall/module.go b/service/firewall/module.go index c7a9dd9af..ec5e4fb7b 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -104,7 +104,7 @@ func prep() error { return err } - return prepAPIAuth() + return nil } func start() error { @@ -136,6 +136,10 @@ func New(instance instance) (*Filter, error) { instance: instance, } + if err := prepAPIAuth(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/instance.go b/service/instance.go index 94183ce86..b582ec2ee 100644 --- a/service/instance.go +++ b/service/instance.go @@ -57,6 +57,7 @@ type Instance struct { rng *rng.Rng base *base.Base + core *core.Core updates *updates.Updates geoip *geoip.GeoIP netenv *netenv.NetEnv @@ -86,7 +87,6 @@ type Instance struct { process *process.ProcessModule resolver *resolver.ResolverModule sync *sync.Sync - core *core.Core } // New returns a new portmaster service instance. @@ -133,6 +133,10 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { } // Global service modules + instance.core, err = core.New(instance) + if err != nil { + return nil, fmt.Errorf("create core module: %w", err) + } instance.updates, err = updates.New(instance, svcCfg.ShutdownFunc) if err != nil { return nil, fmt.Errorf("create updates module: %w", err) @@ -245,10 +249,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create sync module: %w", err) } - instance.core, err = core.New(instance) - if err != nil { - return nil, fmt.Errorf("create core module: %w", err) - } // Add all modules to instance group. 
instance.Group = mgr.NewGroup( @@ -261,6 +261,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.rng, instance.base, + instance.core, instance.updates, instance.geoip, instance.netenv, @@ -290,7 +291,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.process, instance.resolver, instance.sync, - instance.core, ) // FIXME: call this before to trigger shutdown/restart event diff --git a/service/mgr/module.go b/service/mgr/module.go index 019138af5..4c2d0e862 100644 --- a/service/mgr/module.go +++ b/service/mgr/module.go @@ -7,8 +7,34 @@ import ( "reflect" "strings" "sync" + "sync/atomic" ) +const ( + groupStateOff int32 = iota + groupStateStarting + groupStateRunning + groupStateStopping + groupStateInvalid +) + +func groupStateToString(state int32) string { + switch state { + case groupStateOff: + return "off" + case groupStateStarting: + return "starting" + case groupStateRunning: + return "running" + case groupStateStopping: + return "stopping" + case groupStateInvalid: + return "invalid" + } + + return "unknown" +} + // Group describes a group of modules. type Group struct { modules []*groupModule @@ -16,6 +42,8 @@ type Group struct { ctx context.Context cancelCtx context.CancelFunc ctxLock sync.Mutex + + state atomic.Int32 } type groupModule struct { @@ -64,22 +92,42 @@ func NewGroup(modules ...Module) *Group { // If a module fails to start, itself and all previous modules // will be stopped in the reverse order. func (g *Group) Start() error { + if !g.state.CompareAndSwap(groupStateOff, groupStateStarting) { + return fmt.Errorf("group is not off, state: %s", groupStateToString(g.state.Load())) + } + g.initGroupContext() for i, m := range g.modules { + m.mgr.Info("starting") err := m.module.Start(m.mgr) if err != nil { - g.stopFrom(i) + if !g.stopFrom(i) { + g.state.Store(groupStateInvalid) + } else { + g.state.Store(groupStateOff) + } return fmt.Errorf("failed to start %s: %w", makeModuleName(m.module), err) } m.mgr.Info("started") } + g.state.Store(groupStateRunning) return nil } // Stop stops all modules in the group in the reverse order. -func (g *Group) Stop() (ok bool) { - return g.stopFrom(len(g.modules) - 1) +func (g *Group) Stop() error { + if !g.state.CompareAndSwap(groupStateRunning, groupStateStopping) { + return fmt.Errorf("group is not running, state: %s", groupStateToString(g.state.Load())) + } + + if !g.stopFrom(len(g.modules) - 1) { + g.state.Store(groupStateInvalid) + return errors.New("failed to stop") + } + + g.state.Store(groupStateOff) + return nil } func (g *Group) stopFrom(index int) (ok bool) { @@ -150,11 +198,7 @@ func RunModules(ctx context.Context, modules ...Module) error { // Stop module when context is canceled. <-ctx.Done() - if !g.Stop() { - return errors.New("failed to stop") - } - - return nil + return g.Stop() } func makeModuleName(m Module) string { diff --git a/service/mgr/states.go b/service/mgr/states.go index 229c395c3..c7fc6d653 100644 --- a/service/mgr/states.go +++ b/service/mgr/states.go @@ -73,7 +73,7 @@ func (m *StateMgr) Add(s State) { m.states = append(m.states, s) } - m.statesEventMgr.Submit(m.Export()) + m.statesEventMgr.Submit(m.export()) } // Remove removes the state with the given ID. @@ -85,7 +85,7 @@ func (m *StateMgr) Remove(id string) { return s.ID == id }) - m.statesEventMgr.Submit(m.Export()) + m.statesEventMgr.Submit(m.export()) } // Clear removes all states. 
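The new Group state field above is an atomic int32 guarding the lifecycle: Start only proceeds from off, Stop only from running, and a failed stop parks the group in invalid. Reduced to its core, the guard is a compare-and-swap, roughly like this generic sketch (not portmaster code):

package lifecycle

import (
	"errors"
	"sync/atomic"
)

const (
	stateOff int32 = iota
	stateStarting
	stateRunning
)

type Service struct {
	state atomic.Int32
}

// Start may only run while the service is off. A concurrent or repeated
// call loses the compare-and-swap and errors out instead of starting twice.
func (s *Service) Start() error {
	if !s.state.CompareAndSwap(stateOff, stateStarting) {
		return errors.New("not off, refusing to start")
	}
	// ... bring up modules ...
	s.state.Store(stateRunning)
	return nil
}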
@@ -95,14 +95,18 @@ func (m *StateMgr) Clear() { m.states = nil - m.statesEventMgr.Submit(m.Export()) + m.statesEventMgr.Submit(m.export()) } -// Export returns the current states. func (m *StateMgr) Export() StateUpdate { m.statesLock.Lock() defer m.statesLock.Unlock() + return m.export() +} + +// Export returns the current states. +func (m *StateMgr) export() StateUpdate { name := "" if m.mgr != nil { name = m.mgr.name diff --git a/service/mgr/worker.go b/service/mgr/worker.go index e6fa0fb4f..802af72a2 100644 --- a/service/mgr/worker.go +++ b/service/mgr/worker.go @@ -124,7 +124,7 @@ func (w *WorkerCtx) LogAttrs(level slog.Level, msg string, attrs ...slog.Attr) { // - Panic catching. // - Flow control helpers. func (m *Manager) Go(name string, fn func(w *WorkerCtx) error) { - m.logger.Log(m.ctx, slog.LevelInfo, "worker started", "name", name) + // m.logger.Log(m.ctx, slog.LevelInfo, "worker started", "name", name) go m.manageWorker(name, fn) } diff --git a/service/mgr/workermgr.go b/service/mgr/workermgr.go index 9ba716d03..0b40a3a97 100644 --- a/service/mgr/workermgr.go +++ b/service/mgr/workermgr.go @@ -263,7 +263,10 @@ func (s *WorkerMgr) Delay(duration time.Duration) *WorkerMgr { defer s.actionLock.Unlock() s.delay.Stop() - s.delay = s.newDelay(duration) + s.delay = nil + if duration > 0 { + s.delay = s.newDelay(duration) + } s.check() return s @@ -276,7 +279,10 @@ func (s *WorkerMgr) Repeat(interval time.Duration) *WorkerMgr { defer s.actionLock.Unlock() s.repeat.Stop() - s.repeat = s.newRepeat(interval) + s.repeat = nil + if interval > 0 { + s.repeat = s.newRepeat(interval) + } s.check() return s diff --git a/service/profile/module.go b/service/profile/module.go index 960719d18..0233351ab 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -18,8 +18,7 @@ import ( ) var ( - migrations = migration.New("core:migrations/profile") - // module *modules.Module + migrations = migration.New("core:migrations/profile") updatesPath string ) diff --git a/service/updates/main.go b/service/updates/main.go index d445f9e46..fbdaa9dd7 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -80,11 +80,6 @@ const ( ) func init() { - // FIXME: - // module = modules.Register(ModuleName, prep, start, stop, "base") - // module.RegisterEvent(VersionUpdateEvent, true) - // module.RegisterEvent(ResourceUpdateEvent, true) - flag.StringVar(&updateServerFromFlag, "update-server", "", "set an alternative update server (full URL)") flag.StringVar(&userAgentFromFlag, "update-agent", "", "set an alternative user agent for requests to the update server") } @@ -111,8 +106,7 @@ func prep() error { func start() error { initConfig() - _ = module.mgr.Repeat("automatic restart", 10*time.Minute, automaticRestart) - + module.restartWorkerMgr = module.mgr.Repeat("automatic restart", 10*time.Minute, automaticRestart) module.instance.Config().EventConfigChange.AddCallback("update registry config", updateRegistryConfig) // create registry diff --git a/service/updates/module.go b/service/updates/module.go index 6372b83d5..9a8554c39 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -16,7 +16,8 @@ type Updates struct { instance instance shutdownFunc func(exitCode int) - updateWorkerMgr *mgr.WorkerMgr + updateWorkerMgr *mgr.WorkerMgr + restartWorkerMgr *mgr.WorkerMgr EventResourcesUpdated *mgr.EventMgr[struct{}] EventVersionsUpdated *mgr.EventMgr[struct{}] @@ -29,6 +30,7 @@ func (u *Updates) Start(m *mgr.Manager) error { u.mgr = m u.EventResourcesUpdated = 
mgr.NewEventMgr[struct{}](ResourceUpdateEvent, u.mgr) u.EventVersionsUpdated = mgr.NewEventMgr[struct{}](VersionUpdateEvent, u.mgr) + u.States = mgr.NewStateMgr(u.mgr) return start() diff --git a/service/updates/restart.go b/service/updates/restart.go index 64129a685..bed86ea84 100644 --- a/service/updates/restart.go +++ b/service/updates/restart.go @@ -22,7 +22,6 @@ var ( // should be restarted automatically when triggering a restart internally. RebootOnRestart bool - restartWorkerMgr *mgr.WorkerMgr restartPending = abool.New() restartTriggered = abool.New() @@ -60,8 +59,7 @@ func DelayedRestart(delay time.Duration) { // Schedule the restart task. log.Warningf("updates: restart triggered, will execute in %s", delay) restartAt := time.Now().Add(delay) - // FIXME(vladimir): provide restart task - // restartTask.Schedule(restartAt) + module.restartWorkerMgr.Delay(delay) // Set restartTime. restartTimeLock.Lock() @@ -75,8 +73,7 @@ func AbortRestart() { log.Warningf("updates: restart aborted") // Cancel schedule. - // FIXME(vladimir): provide restart task - // restartTask.Schedule(time.Time{}) + module.restartWorkerMgr.Delay(0) } } @@ -84,8 +81,7 @@ func AbortRestart() { // This can be used to prepone a scheduled restart if the conditions are preferable. func TriggerRestartIfPending() { if restartPending.IsSet() { - // FIXME(vladimir): provide restart task - // restartTask.StartASAP() + module.restartWorkerMgr.Go() } } @@ -93,8 +89,7 @@ func TriggerRestartIfPending() { // This only works if the process is managed by portmaster-start. func RestartNow() { restartPending.Set() - // FIXME(vladimir): provide restart task - // restartTask.StartASAP() + module.restartWorkerMgr.Go() } func automaticRestart(w *mgr.WorkerCtx) error { From b9edb7ea1f4804a343a2a836ec7bda6e7c2447de Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Tue, 2 Jul 2024 12:16:46 +0300 Subject: [PATCH 18/56] [WIP] Improve startup process --- cmds/portmaster-core/main.go | 7 ++ service/firewall/interception/module.go | 6 +- .../firewall/interception/nfqueue_linux.go | 10 +- service/mgr/module.go | 15 ++- service/mgr/worker.go | 2 +- service/network/clean.go | 3 +- spn/captain/module.go | 114 +++++++++--------- spn/navigator/module.go | 71 +++++------ 8 files changed, 126 insertions(+), 102 deletions(-) diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index b200a7d44..b61173cbe 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -4,7 +4,14 @@ package main import ( "flag" "fmt" + "io" + "log/slog" + "os" + "os/signal" "runtime" + "runtime/pprof" + "syscall" + "time" "github.com/safing/portmaster/base/info" "github.com/safing/portmaster/base/log" diff --git a/service/firewall/interception/module.go b/service/firewall/interception/module.go index 189d1bfcb..76964f1af 100644 --- a/service/firewall/interception/module.go +++ b/service/firewall/interception/module.go @@ -68,8 +68,10 @@ func stop() error { } close(metrics.done) - - return stopInterception() + if err := stopInterception(); err != nil { + log.Errorf("failed to stop interception module: %s", err) + } + return nil } var ( diff --git a/service/firewall/interception/nfqueue_linux.go b/service/firewall/interception/nfqueue_linux.go index cbaad7cce..bff94fe28 100644 --- a/service/firewall/interception/nfqueue_linux.go +++ b/service/firewall/interception/nfqueue_linux.go @@ -258,30 +258,30 @@ func StartNfqueueInterception(packets chan<- packet.Packet) (err error) { err = activateNfqueueFirewall() if err != nil { - _ = 
StopNfqueueInterception() + // _ = StopNfqueueInterception() return fmt.Errorf("could not initialize nfqueue: %w", err) } out4Queue, err = nfq.New(17040, false) if err != nil { - _ = StopNfqueueInterception() + // _ = StopNfqueueInterception() return fmt.Errorf("nfqueue(IPv4, out): %w", err) } in4Queue, err = nfq.New(17140, false) if err != nil { - _ = StopNfqueueInterception() + // _ = StopNfqueueInterception() return fmt.Errorf("nfqueue(IPv4, in): %w", err) } if netenv.IPv6Enabled() { out6Queue, err = nfq.New(17060, true) if err != nil { - _ = StopNfqueueInterception() + // _ = StopNfqueueInterception() return fmt.Errorf("nfqueue(IPv6, out): %w", err) } in6Queue, err = nfq.New(17160, true) if err != nil { - _ = StopNfqueueInterception() + // _ = StopNfqueueInterception() return fmt.Errorf("nfqueue(IPv6, in): %w", err) } } else { diff --git a/service/mgr/module.go b/service/mgr/module.go index 4c2d0e862..e900eb2a8 100644 --- a/service/mgr/module.go +++ b/service/mgr/module.go @@ -8,6 +8,7 @@ import ( "strings" "sync" "sync/atomic" + "time" ) const ( @@ -100,7 +101,11 @@ func (g *Group) Start() error { for i, m := range g.modules { m.mgr.Info("starting") - err := m.module.Start(m.mgr) + startTime := time.Now() + + err := m.mgr.Do(m.mgr.name+" Start", func(_ *WorkerCtx) error { + return m.module.Start(m.mgr) + }) if err != nil { if !g.stopFrom(i) { g.state.Store(groupStateInvalid) @@ -109,7 +114,8 @@ func (g *Group) Start() error { } return fmt.Errorf("failed to start %s: %w", makeModuleName(m.module), err) } - m.mgr.Info("started") + duration := time.Since(startTime) + m.mgr.Info("started " + duration.String()) } g.state.Store(groupStateRunning) return nil @@ -134,7 +140,10 @@ func (g *Group) stopFrom(index int) (ok bool) { ok = true for i := index; i >= 0; i-- { m := g.modules[i] - err := m.module.Stop(m.mgr) + + err := m.mgr.Do(m.mgr.name+" Stop", func(_ *WorkerCtx) error { + return m.module.Stop(m.mgr) + }) if err != nil { m.mgr.Error("failed to stop", "err", err) ok = false diff --git a/service/mgr/worker.go b/service/mgr/worker.go index 802af72a2..22d553927 100644 --- a/service/mgr/worker.go +++ b/service/mgr/worker.go @@ -229,7 +229,7 @@ func (m *Manager) Do(name string, fn func(w *WorkerCtx) error) error { return nil case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): - // A canceled context or dexceeded eadline also means that the worker is finished. + // A canceled context or exceeded deadline also means that the worker is finished. return err default: diff --git a/service/network/clean.go b/service/network/clean.go index 962fef332..61767dc2c 100644 --- a/service/network/clean.go +++ b/service/network/clean.go @@ -53,7 +53,8 @@ func connectionCleaner(ctx *mgr.WorkerCtx) error { func cleanConnections() (activePIDs map[int]struct{}) { activePIDs = make(map[int]struct{}) - module.mgr.Go("clean connections", func(ctx *mgr.WorkerCtx) error { + // FIXME(vladimir): This was previously a MicroTask but it does not seem right, to run it asynchronously. Is'nt activePIDs going to be used after the function is called? + _ = module.mgr.Do("clean connections", func(ctx *mgr.WorkerCtx) error { now := time.Now().UTC() nowUnix := now.Unix() ignoreNewer := nowUnix - 2 diff --git a/spn/captain/module.go b/spn/captain/module.go index 4c82c1fc2..ca2b84875 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -128,75 +128,77 @@ func start() error { ships.EnableMasking(maskingBytes) // Initialize intel. 
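The clean.go change above swaps Manager.Go for Manager.Do because the caller reads activePIDs immediately after the call: Do blocks until the worker has finished and hands back its error, while Go only schedules the worker in the background. The difference in shape, with cleanWorker standing in for the actual cleanup function:

package example

import (
	"github.com/safing/portmaster/base/log"
	"github.com/safing/portmaster/service/mgr"
)

// cleanWorker is a stand-in for the actual cleanup function.
func cleanWorker(wc *mgr.WorkerCtx) error {
	// ... scan and clean connections ...
	return nil
}

func runCleanup(m *mgr.Manager) {
	// Fire-and-forget: Go schedules the worker and returns immediately.
	m.Go("clean connections", cleanWorker)

	// Blocking: Do returns only after the worker has finished, so results the
	// worker prepares (like activePIDs in the hunk above) can be used right after.
	if err := m.Do("clean connections", cleanWorker); err != nil {
		log.Errorf("network: cleaning connections failed: %s", err)
	}
}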
- if err := registerIntelUpdateHook(); err != nil { - return err - } - if err := updateSPNIntel(module.mgr.Ctx(), nil); err != nil { - log.Errorf("spn/captain: failed to update SPN intel: %s", err) - } - - // Initialize identity and piers. - if conf.PublicHub() { - // Load identity. - if err := loadPublicIdentity(); err != nil { - // We cannot recover from this, set controlled failure (do not retry). - module.shutdownFunc(controlledFailureExitCode) - + module.mgr.Go("start", func(wc *mgr.WorkerCtx) error { + if err := registerIntelUpdateHook(); err != nil { return err } + if err := updateSPNIntel(module.mgr.Ctx(), nil); err != nil { + log.Errorf("spn/captain: failed to update SPN intel: %s", err) + } + + // Initialize identity and piers. + if conf.PublicHub() { + // Load identity. + if err := loadPublicIdentity(); err != nil { + // We cannot recover from this, set controlled failure (do not retry). + module.shutdownFunc(controlledFailureExitCode) - // Check if any networks are configured. - if !conf.HubHasIPv4() && !conf.HubHasIPv6() { - // We cannot recover from this, set controlled failure (do not retry). - module.shutdownFunc(controlledFailureExitCode) + return err + } + + // Check if any networks are configured. + if !conf.HubHasIPv4() && !conf.HubHasIPv6() { + // We cannot recover from this, set controlled failure (do not retry). + module.shutdownFunc(controlledFailureExitCode) + + return errors.New("no IP addresses for Hub configured (or detected)") + } - return errors.New("no IP addresses for Hub configured (or detected)") + // Start management of identity and piers. + if err := prepPublicIdentityMgmt(); err != nil { + return err + } + // Set ID to display on http info page. + ships.DisplayHubID = publicIdentity.ID + // Start listeners. + if err := startPiers(); err != nil { + return err + } + + // Enable connect operation. + crew.EnableConnecting(publicIdentity.Hub) } - // Start management of identity and piers. - if err := prepPublicIdentityMgmt(); err != nil { + // Subscribe to updates of cranes. + startDockHooks() + + // bootstrapping + if err := processBootstrapHubFlag(); err != nil { return err } - // Set ID to display on http info page. - ships.DisplayHubID = publicIdentity.ID - // Start listeners. - if err := startPiers(); err != nil { + if err := processBootstrapFileFlag(); err != nil { return err } - // Enable connect operation. - crew.EnableConnecting(publicIdentity.Hub) - } - - // Subscribe to updates of cranes. - startDockHooks() - - // bootstrapping - if err := processBootstrapHubFlag(); err != nil { - return err - } - if err := processBootstrapFileFlag(); err != nil { - return err - } - - // network optimizer - if conf.PublicHub() { - module.mgr.Delay("optimize network delay", 15*time.Second, optimizeNetwork).Repeat(1 * time.Minute) - } - - // client + home hub manager - if conf.Client() { - module.mgr.Go("client manager", clientManager) + // network optimizer + if conf.PublicHub() { + module.mgr.Delay("optimize network delay", 15*time.Second, optimizeNetwork).Repeat(1 * time.Minute) + } - // Reset failing hubs when the network changes while not connected. - module.instance.NetEnv().EventNetworkChange.AddCallback("reset failing hubs", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { - if ready.IsNotSet() { - navigator.Main.ResetFailingStates(module.mgr.Ctx()) - } - return false, nil - }) - } + // client + home hub manager + if conf.Client() { + module.mgr.Go("client manager", clientManager) + // Reset failing hubs when the network changes while not connected. 
+ module.instance.NetEnv().EventNetworkChange.AddCallback("reset failing hubs", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { + if ready.IsNotSet() { + navigator.Main.ResetFailingStates(module.mgr.Ctx()) + } + return false, nil + }) + } + return nil + }) return nil } diff --git a/spn/navigator/module.go b/spn/navigator/module.go index 41a2cc703..b763b893d 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -80,49 +80,52 @@ func start() error { return err } - // Wait for geoip databases to be ready. - // Try again if not yet ready, as this is critical. - // The "wait" parameter times out after 1 second. - // Allow 30 seconds for both databases to load. -geoInitCheck: - for i := 0; i < 30; i++ { - switch { - case !geoip.IsInitialized(false, true): // First, IPv4. - case !geoip.IsInitialized(true, true): // Then, IPv6. - default: - break geoInitCheck + module.mgr.Go("initializing hubs", func(wc *mgr.WorkerCtx) error { + // Wait for geoip databases to be ready. + // Try again if not yet ready, as this is critical. + // The "wait" parameter times out after 1 second. + // Allow 30 seconds for both databases to load. + geoInitCheck: + for i := 0; i < 30; i++ { + switch { + case !geoip.IsInitialized(false, true): // First, IPv4. + case !geoip.IsInitialized(true, true): // Then, IPv6. + default: + break geoInitCheck + } } - } - err = Main.InitializeFromDatabase() - if err != nil { - // Wait for three seconds, then try again. - time.Sleep(3 * time.Second) err = Main.InitializeFromDatabase() if err != nil { - // Even if the init fails, we can try to start without it and get data along the way. - log.Warningf("spn/navigator: %s", err) + // Wait for three seconds, then try again. + time.Sleep(3 * time.Second) + err = Main.InitializeFromDatabase() + if err != nil { + // Even if the init fails, we can try to start without it and get data along the way. + log.Warningf("spn/navigator: %s", err) + } + } + err = Main.RegisterHubUpdateHook() + if err != nil { + return err } - } - err = Main.RegisterHubUpdateHook() - if err != nil { - return err - } - // TODO: delete superseded hubs after x amount of time - _ = module.mgr.Delay("update states", 3*time.Minute, Main.updateStates).Repeat(1 * time.Hour) - _ = module.mgr.Delay("update failing states", 3*time.Minute, Main.updateFailingStates).Repeat(1 * time.Minute) + // TODO: delete superseded hubs after x amount of time + _ = module.mgr.Delay("update states", 3*time.Minute, Main.updateStates).Repeat(1 * time.Hour) + _ = module.mgr.Delay("update failing states", 3*time.Minute, Main.updateFailingStates).Repeat(1 * time.Minute) - if conf.PublicHub() { - // Only measure Hubs on public Hubs. - module.mgr.Delay("measure hubs", 5*time.Minute, Main.measureHubs).Repeat(1 * time.Minute) + if conf.PublicHub() { + // Only measure Hubs on public Hubs. + module.mgr.Delay("measure hubs", 5*time.Minute, Main.measureHubs).Repeat(1 * time.Minute) - // Only register metrics on Hubs, as they only make sense there. - err := registerMetrics() - if err != nil { - return err + // Only register metrics on Hubs, as they only make sense there. 
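Both captain and navigator now wrap their slow start-up work (intel updates, identity loading, database initialization, the geoip wait loop) in a named background worker, so that the module's Start returns quickly and Group.Start is not blocked by network or disk waits. The pattern in isolation, with heavyInit as a purely illustrative placeholder:

package example

import "github.com/safing/portmaster/service/mgr"

// startModule defers heavy initialization to a background worker so the
// module group's start sequence is not blocked by network or disk waits.
func startModule(m *mgr.Manager, heavyInit func(*mgr.WorkerCtx) error) error {
	m.Go("start", heavyInit)
	return nil
}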
+ err := registerMetrics() + if err != nil { + return err + } } - } + return nil + }) return nil } From 594e53defd2ac48365efc50c3c9c593afea9f859 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Wed, 3 Jul 2024 13:51:06 +0300 Subject: [PATCH 19/56] [WIP] Fix minor issues --- service/firewall/module.go | 12 ++++++------ service/instance.go | 26 +++++++++++++++++++------- service/sync/settings.go | 2 +- service/updates/module.go | 5 ++++- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/service/firewall/module.go b/service/firewall/module.go index ec5e4fb7b..99af76d88 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -32,7 +32,7 @@ func (ss *stringSliceFlag) Set(value string) error { // module *modules.Module var allowedClients stringSliceFlag -type Filter struct { +type Firewall struct { mgr *mgr.Manager instance instance @@ -42,7 +42,7 @@ func init() { flag.Var(&allowedClients, "allowed-clients", "A list of binaries that are allowed to connect to the Portmaster API") } -func (f *Filter) Start(mgr *mgr.Manager) error { +func (f *Firewall) Start(mgr *mgr.Manager) error { f.mgr = mgr if err := prep(); err != nil { @@ -53,7 +53,7 @@ func (f *Filter) Start(mgr *mgr.Manager) error { return start() } -func (f *Filter) Stop(mgr *mgr.Manager) error { +func (f *Firewall) Stop(mgr *mgr.Manager) error { return stop() } @@ -127,12 +127,12 @@ func stop() error { } var ( - module *Filter + module *Firewall shimLoaded atomic.Bool ) -func New(instance instance) (*Filter, error) { - module = &Filter{ +func New(instance instance) (*Firewall, error) { + module = &Firewall{ instance: instance, } diff --git a/service/instance.go b/service/instance.go index b582ec2ee..a896b960e 100644 --- a/service/instance.go +++ b/service/instance.go @@ -17,6 +17,7 @@ import ( "github.com/safing/portmaster/service/firewall" "github.com/safing/portmaster/service/firewall/interception" "github.com/safing/portmaster/service/intel/customlists" + "github.com/safing/portmaster/service/intel/filterlists" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/nameserver" @@ -77,7 +78,8 @@ type Instance struct { profile *profile.ProfileModule network *network.Network netquery *netquery.NetQuery - filter *firewall.Filter + firewall *firewall.Firewall + filterLists *filterlists.FilterLists interception *interception.Interception customlist *customlists.CustomList status *status.Status @@ -101,7 +103,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { // Base modules instance.database, err = dbmodule.New(instance) if err != nil { - return nil, fmt.Errorf("create config module: %w", err) + return nil, fmt.Errorf("create database module: %w", err) } instance.config, err = config.New(instance) if err != nil { @@ -209,9 +211,13 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create netquery module: %w", err) } - instance.filter, err = firewall.New(instance) + instance.firewall, err = firewall.New(instance) + if err != nil { + return nil, fmt.Errorf("create firewall module: %w", err) + } + instance.filterLists, err = filterlists.New(instance) if err != nil { - return nil, fmt.Errorf("create filter module: %w", err) + return nil, fmt.Errorf("create filterLists module: %w", err) } instance.interception, err = interception.New(instance) if err != nil { @@ -281,7 +287,8 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { 
instance.profile, instance.network, instance.netquery, - instance.filter, + instance.firewall, + instance.filterLists, instance.interception, instance.customlist, instance.status, @@ -426,8 +433,13 @@ func (i *Instance) Profile() *profile.ProfileModule { } // Firewall returns the firewall module. -func (i *Instance) Firewall() *firewall.Filter { - return i.filter +func (i *Instance) Firewall() *firewall.Firewall { + return i.firewall +} + +// FilterLists returns the filterLists module. +func (i *Instance) FilterLists() *filterlists.FilterLists { + return i.filterLists } // Interception returns the interception module. diff --git a/service/sync/settings.go b/service/sync/settings.go index 3a7dd8e0e..372f63db2 100644 --- a/service/sync/settings.go +++ b/service/sync/settings.go @@ -154,7 +154,7 @@ func handleImportSettings(ar *api.Request) (any, error) { } request.Export = export case request.Export != nil: - // Export is aleady parsed. + // Export is already parsed. default: return nil, ErrInvalidImportRequest } diff --git a/service/updates/module.go b/service/updates/module.go index 9a8554c39..7736b4023 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -30,9 +30,12 @@ func (u *Updates) Start(m *mgr.Manager) error { u.mgr = m u.EventResourcesUpdated = mgr.NewEventMgr[struct{}](ResourceUpdateEvent, u.mgr) u.EventVersionsUpdated = mgr.NewEventMgr[struct{}](VersionUpdateEvent, u.mgr) - u.States = mgr.NewStateMgr(u.mgr) + if err := prep(); err != nil { + return err + } + return start() } From 8c3109ad3f6d9fc9107bf7812a32b3c309de96ae Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Wed, 3 Jul 2024 13:51:58 +0300 Subject: [PATCH 20/56] [WIP] Fix missing subsystem in settings --- .../src/app/shared/config/config-settings.ts | 25 +- .../src/app/shared/config/subsystems.ts | 433 ++++++++++++++++++ 2 files changed, 439 insertions(+), 19 deletions(-) create mode 100644 desktop/angular/src/app/shared/config/subsystems.ts diff --git a/desktop/angular/src/app/shared/config/config-settings.ts b/desktop/angular/src/app/shared/config/config-settings.ts index 49301abf1..513e6c4b9 100644 --- a/desktop/angular/src/app/shared/config/config-settings.ts +++ b/desktop/angular/src/app/shared/config/config-settings.ts @@ -24,7 +24,7 @@ import { } from '@safing/portmaster-api'; import { BehaviorSubject, Subscription, combineLatest } from 'rxjs'; import { debounceTime } from 'rxjs/operators'; -import { StatusService, Subsystem } from 'src/app/services'; +import { StatusService } from 'src/app/services'; import { fadeInAnimation, fadeInListAnimation, @@ -44,6 +44,8 @@ import { ImportDialogComponent, } from './import-dialog/import-dialog.component'; +import { subsystems, SubsystemWithExpertise } from './subsystems' + interface Category { name: string; settings: Setting[]; @@ -52,12 +54,6 @@ interface Category { hasUserDefinedValues: boolean; } -interface SubsystemWithExpertise extends Subsystem { - minimumExpertise: ExpertiseLevelNumber; - isDisabled: boolean; - hasUserDefinedValues: boolean; -} - @Component({ selector: 'app-settings-view', templateUrl: './config-settings.html', @@ -66,7 +62,7 @@ interface SubsystemWithExpertise extends Subsystem { }) export class ConfigSettingsViewComponent implements OnInit, OnDestroy, AfterViewInit { - subsystems: SubsystemWithExpertise[] = []; + subsystems: SubsystemWithExpertise[] = subsystems; others: Setting[] | null = null; settings: Map = new Map(); @@ -207,7 +203,7 @@ export class ConfigSettingsViewComponent private searchService: 
FuzzySearchService, private actionIndicator: ActionIndicatorService, private portapi: PortapiService, - private dialog: SfngDialogService + private dialog: SfngDialogService, ) { } openImportDialog() { @@ -303,21 +299,12 @@ export class ConfigSettingsViewComponent ngOnInit(): void { this.subscription = combineLatest([ this.onSettingsChange, - this.statusService.querySubsystem(), this.onSearch.pipe(debounceTime(250)), this.configService.watch('core/releaseLevel'), ]) .pipe(debounceTime(10)) .subscribe( - ([settings, subsystems, searchTerm, currentReleaseLevelSetting]) => { - this.subsystems = subsystems.map((s) => ({ - ...s, - // we start with developer and decrease to the lowest number required - // while grouping the settings. - minimumExpertise: ExpertiseLevelNumber.developer, - isDisabled: false, - hasUserDefinedValues: false, - })); + ([settings, searchTerm, currentReleaseLevelSetting]) => { this.others = []; this.settings = new Map(); diff --git a/desktop/angular/src/app/shared/config/subsystems.ts b/desktop/angular/src/app/shared/config/subsystems.ts new file mode 100644 index 000000000..2071c8074 --- /dev/null +++ b/desktop/angular/src/app/shared/config/subsystems.ts @@ -0,0 +1,433 @@ +import { ExpertiseLevelNumber } from "@safing/portmaster-api"; +import { ModuleStatus, Subsystem } from "src/app/services/status.types"; + +export interface SubsystemWithExpertise extends Subsystem { + minimumExpertise: ExpertiseLevelNumber; + isDisabled: boolean; + hasUserDefinedValues: boolean; +} + +export var subsystems : SubsystemWithExpertise[] = [ + { + minimumExpertise: ExpertiseLevelNumber.developer, + isDisabled: false, + hasUserDefinedValues: false, + ID: "core", + Name: "Core", + Description: "Base Structure and System Integration", + Modules: [ + { + Name: "core", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "subsystems", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "runtime", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "status", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "ui", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "compat", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "broadcasts", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "sync", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + } + ], + FailureStatus: 0, + ToggleOptionKey: "", + ExpertiseLevel: "user", + ReleaseLevel: 0, + ConfigKeySpace: "config:core/", + _meta: { + Created: 0, + Modified: 0, + Expires: 0, + Deleted: 0, + Key: "runtime:subsystems/core" + } + }, + { + minimumExpertise: ExpertiseLevelNumber.developer, + isDisabled: false, + hasUserDefinedValues: false, + ID: "dns", + Name: "Secure DNS", + Description: "DNS resolver with scoping and DNS-over-TLS", + Modules: [ + { + Name: "nameserver", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "resolver", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", 
+ FailureMsg: "" + } + ], + FailureStatus: 0, + ToggleOptionKey: "", + ExpertiseLevel: "user", + ReleaseLevel: 0, + ConfigKeySpace: "config:dns/", + _meta: { + Created: 0, + Modified: 0, + Expires: 0, + Deleted: 0, + Key: "runtime:subsystems/dns" + } + }, + { + minimumExpertise: ExpertiseLevelNumber.developer, + isDisabled: false, + hasUserDefinedValues: false, + ID: "filter", + Name: "Privacy Filter", + Description: "DNS and Network Filter", + Modules: [ + { + Name: "filter", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "interception", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "base", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "database", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "config", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "rng", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "metrics", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "api", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "updates", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "network", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "netenv", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "processes", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "profiles", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "notifications", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "intel", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "geoip", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "filterlists", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "customlists", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + } + ], + FailureStatus: 0, + ToggleOptionKey: "", + ExpertiseLevel: "user", + ReleaseLevel: 0, + ConfigKeySpace: "config:filter/", + _meta: { + Created: 0, + Modified: 0, + Expires: 0, + Deleted: 0, + Key: "runtime:subsystems/filter" + } + }, + { + minimumExpertise: ExpertiseLevelNumber.developer, + isDisabled: false, + hasUserDefinedValues: false, + ID: "history", + Name: "Network History", + Description: "Keep Network History Data", + Modules: [ + { + Name: "netquery", + Enabled: true, + Status: ModuleStatus.Operational, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + } + ], + FailureStatus: 0, + ToggleOptionKey: "", + ExpertiseLevel: "user", + 
ReleaseLevel: 0, + ConfigKeySpace: "config:history/", + _meta: { + Created: 0, + Modified: 0, + Expires: 0, + Deleted: 0, + Key: "runtime:subsystems/history" + } + }, + { + minimumExpertise: ExpertiseLevelNumber.developer, + isDisabled: false, + hasUserDefinedValues: false, + ID: "spn", + Name: "SPN", + Description: "Safing Privacy Network", + Modules: [ + { + Name: "captain", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "terminal", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "cabin", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "ships", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "docks", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "access", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "crew", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "navigator", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "sluice", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + }, + { + Name: "patrol", + Enabled: false, + Status: 2, + FailureStatus: 0, + FailureID: "", + FailureMsg: "" + } + ], + FailureStatus: 0, + ToggleOptionKey: "spn/enable", + ExpertiseLevel: "user", + ReleaseLevel: 0, + ConfigKeySpace: "config:spn/", + _meta: { + Created: 0, + Modified: 0, + Expires: 0, + Deleted: 0, + Key: "runtime:subsystems/spn" + } + } +]; From 286a0d36cb22236b897bf1711c686752897340b1 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 5 Jul 2024 14:20:17 +0200 Subject: [PATCH 21/56] [WIP] Initialize managers in constructor --- service/mgr/module.go | 5 ++-- service/mgr/states.go | 12 +++++--- service/mgr/worker.go | 2 +- service/updates/module.go | 62 ++++++++++++++++++++++----------------- 4 files changed, 47 insertions(+), 34 deletions(-) diff --git a/service/mgr/module.go b/service/mgr/module.go index e900eb2a8..8f2867c37 100644 --- a/service/mgr/module.go +++ b/service/mgr/module.go @@ -54,8 +54,9 @@ type groupModule struct { // Module is an manage-able instance of some component. type Module interface { - Start(mgr *Manager) error - Stop(mgr *Manager) error + Manager() *Manager + Start() error + Stop() error } // NewGroup returns a new group of modules. diff --git a/service/mgr/states.go b/service/mgr/states.go index c7fc6d653..e6f049219 100644 --- a/service/mgr/states.go +++ b/service/mgr/states.go @@ -43,9 +43,7 @@ type StateUpdate struct { States []State } -// NewStateMgr returns a new event manager. -// It is easiest used as a public field on a struct, -// so that others can simply Subscribe() oder AddCallback(). +// NewStateMgr returns a new state manager. func NewStateMgr(mgr *Manager) *StateMgr { return &StateMgr{ statesEventMgr: NewEventMgr[StateUpdate]("state update", mgr), @@ -53,6 +51,11 @@ func NewStateMgr(mgr *Manager) *StateMgr { } } +// NewStateMgr returns a new state manager. +func (m *Manager) NewStateMgr() *StateMgr { + return NewStateMgr(m) +} + // Add adds a state. // If a state with the same ID already exists, it is replaced. func (m *StateMgr) Add(s State) { @@ -98,6 +101,7 @@ func (m *StateMgr) Clear() { m.statesEventMgr.Submit(m.export()) } +// Export returns the current states. 
func (m *StateMgr) Export() StateUpdate { m.statesLock.Lock() defer m.statesLock.Unlock() @@ -105,7 +109,7 @@ func (m *StateMgr) Export() StateUpdate { return m.export() } -// Export returns the current states. +// export returns the current states. func (m *StateMgr) export() StateUpdate { name := "" if m.mgr != nil { diff --git a/service/mgr/worker.go b/service/mgr/worker.go index 22d553927..9f7eb2ee0 100644 --- a/service/mgr/worker.go +++ b/service/mgr/worker.go @@ -53,7 +53,7 @@ func (w *WorkerCtx) Cancel() { w.cancelCtx() } -// Scheduler returns the scheduler the worker was started from. +// WorkerMgr returns the worker manager the worker was started from. // Returns nil if the worker is not associated with a scheduler. func (w *WorkerCtx) WorkerMgr() *WorkerMgr { return w.workerMgr diff --git a/service/updates/module.go b/service/updates/module.go index 7736b4023..87ade51a4 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -11,10 +11,8 @@ import ( // Updates provides access to released artifacts. type Updates struct { - mgr *mgr.Manager - - instance instance - shutdownFunc func(exitCode int) + m *mgr.Manager + states *mgr.StateMgr updateWorkerMgr *mgr.WorkerMgr restartWorkerMgr *mgr.WorkerMgr @@ -22,16 +20,43 @@ type Updates struct { EventResourcesUpdated *mgr.EventMgr[struct{}] EventVersionsUpdated *mgr.EventMgr[struct{}] - States *mgr.StateMgr + instance instance + shutdownFunc func(exitCode int) +} + +var ( + module *Updates + shimLoaded atomic.Bool +) + +// New returns a new UI module. +func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + m := mgr.New("updates") + module = &Updates{ + m: m, + states: m.NewStateMgr(), + updateWorkerMgr: m.NewWorkerMgr("updater", checkForUpdates, nil), //FIXME + restartWorkerMgr: m.NewWorkerMgr("updater", checkForUpdates, nil), //FIXME + EventResourcesUpdated: mgr.NewEventMgr[struct{}](ResourceUpdateEvent, m), + EventVersionsUpdated: mgr.NewEventMgr[struct{}](VersionUpdateEvent, m), + instance: instance, + shutdownFunc: shutdownFunc, + } + + return module, nil +} + +// Manager returns the module manager. +func (u *Updates) Manager() *mgr.Manager { + return u.m } // Start starts the module. func (u *Updates) Start(m *mgr.Manager) error { - u.mgr = m - u.EventResourcesUpdated = mgr.NewEventMgr[struct{}](ResourceUpdateEvent, u.mgr) - u.EventVersionsUpdated = mgr.NewEventMgr[struct{}](VersionUpdateEvent, u.mgr) - u.States = mgr.NewStateMgr(u.mgr) - if err := prep(); err != nil { return err } @@ -44,23 +69,6 @@ func (u *Updates) Stop(_ *mgr.Manager) error { return stop() } -var ( - module *Updates - shimLoaded atomic.Bool -) - -// New returns a new UI module. 
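With the constructor above creating restartWorkerMgr up front, the restart scheduling added earlier in this series (the DelayedRestart, AbortRestart and RestartNow hunks) drives that worker manager directly: Delay(d) arms a one-shot run, Delay(0) clears it again (the patched Delay and Repeat treat a zero duration as "off"), and Go() runs the worker as soon as possible. In sketch form:

package example

import (
	"time"

	"github.com/safing/portmaster/service/mgr"
)

// controlRestart shows how the restart worker manager is driven directly.
func controlRestart(restart *mgr.WorkerMgr) {
	// Arm a one-shot run of the restart worker, ten minutes from now.
	restart.Delay(10 * time.Minute)

	// Abort it again: a zero duration removes the pending delay timer.
	restart.Delay(0)

	// Or run the restart worker as soon as possible.
	restart.Go()
}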
-func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { - if shimLoaded.CompareAndSwap(false, true) { - module = &Updates{ - instance: instance, - shutdownFunc: shutdownFunc, - } - return module, nil - } - return nil, errors.New("only one instance allowed") -} - type instance interface { API() *api.API Config() *config.Config From 6d835961d08dbe46b68a38251753f334d45335d6 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Mon, 8 Jul 2024 14:50:42 +0300 Subject: [PATCH 22/56] [WIP] Move module event initialization to constrictors --- base/api/module.go | 12 ++++++++---- base/config/module.go | 17 ++++++++++------- base/database/dbmodule/db.go | 12 ++++++++---- base/metrics/module.go | 11 +++++++---- base/notifications/module.go | 14 ++++++++++---- base/rng/rng.go | 20 ++++++++++++-------- base/runtime/module.go | 15 ++++++++++++--- service/broadcasts/module.go | 12 ++++++++---- service/compat/module.go | 12 ++++++++---- service/core/base/module.go | 12 ++++++++---- service/core/core.go | 15 +++++++++++---- service/firewall/interception/module.go | 12 ++++++++---- service/firewall/module.go | 15 ++++++++++++--- service/intel/customlists/module.go | 17 ++++++++++------- service/intel/filterlists/module.go | 14 +++++++++----- service/intel/geoip/module.go | 11 ++++++++--- service/mgr/module.go | 4 ++-- service/nameserver/module.go | 15 ++++++++++----- service/netenv/main.go | 20 +++++++++++++------- service/netquery/module_api.go | 12 ++++++++---- service/network/module.go | 15 +++++++++------ service/process/module.go | 12 +++++++++--- service/profile/module.go | 23 +++++++++++++---------- service/resolver/main.go | 14 +++++++++----- service/status/module.go | 12 +++++++++--- service/sync/module.go | 12 +++++++++--- service/ui/module.go | 23 ++++++++++++++--------- service/updates/config.go | 2 +- service/updates/main.go | 6 ++---- service/updates/module.go | 20 +++++++++++++------- service/updates/notify.go | 2 +- spn/access/module.go | 15 ++++++++++----- spn/cabin/module.go | 11 +++++++++-- spn/captain/module.go | 18 +++++++++++------- spn/crew/module.go | 12 ++++++++---- spn/docks/module.go | 12 ++++++++---- spn/navigator/module.go | 12 ++++++++---- spn/patrol/module.go | 16 ++++++++++------ spn/ships/module.go | 12 ++++++++---- spn/sluice/module.go | 12 ++++++++---- spn/terminal/module.go | 12 ++++++++---- 41 files changed, 359 insertions(+), 186 deletions(-) diff --git a/base/api/module.go b/base/api/module.go index 01ea043ac..d75d29f22 100644 --- a/base/api/module.go +++ b/base/api/module.go @@ -16,9 +16,12 @@ type API struct { online bool } +func (api *API) Manager() *mgr.Manager { + return api.mgr +} + // Start starts the module. -func (api *API) Start(m *mgr.Manager) error { - api.mgr = m +func (api *API) Start() error { if err := prep(); err != nil { return err } @@ -31,7 +34,7 @@ func (api *API) Start(m *mgr.Manager) error { } // Stop stops the module. -func (api *API) Stop(_ *mgr.Manager) error { +func (api *API) Stop() error { return stop() } @@ -45,8 +48,9 @@ func New(instance instance) (*API, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("API") module = &API{ + mgr: m, instance: instance, } diff --git a/base/config/module.go b/base/config/module.go index e3bc96f4f..a83660f69 100644 --- a/base/config/module.go +++ b/base/config/module.go @@ -16,11 +16,12 @@ type Config struct { EventConfigChange *mgr.EventMgr[struct{}] } -// Start starts the module. 
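These last two patches rework the mgr.Module contract: modules now expose Manager() *mgr.Manager, Start and Stop take no arguments, and each constructor builds its manager (plus its event and state managers) before the group ever starts. A minimal module following the new contract, with all names chosen purely for illustration:

package example

import (
	"errors"
	"sync/atomic"

	"github.com/safing/portmaster/service/mgr"
)

// Example is a minimal module under the reworked interface.
type Example struct {
	mgr *mgr.Manager

	EventChanged *mgr.EventMgr[struct{}]
}

func (e *Example) Manager() *mgr.Manager { return e.mgr }
func (e *Example) Start() error          { return nil }
func (e *Example) Stop() error           { return nil }

var shimLoaded atomic.Bool

// New builds the manager and events up front, mirroring the constructors
// rewritten in this patch; only one instance is allowed.
func New() (*Example, error) {
	if !shimLoaded.CompareAndSwap(false, true) {
		return nil, errors.New("only one instance allowed")
	}
	m := mgr.New("Example")
	return &Example{
		mgr:          m,
		EventChanged: mgr.NewEventMgr[struct{}]("changed", m),
	}, nil
}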
-func (u *Config) Start(m *mgr.Manager) error { - u.mgr = m - u.EventConfigChange = mgr.NewEventMgr[struct{}](ChangeEvent, u.mgr) +func (u *Config) Manager() *mgr.Manager { + return u.mgr +} +// Start starts the module. +func (u *Config) Start() error { if err := prep(); err != nil { return err } @@ -28,7 +29,7 @@ func (u *Config) Start(m *mgr.Manager) error { } // Stop stops the module. -func (u *Config) Stop(_ *mgr.Manager) error { +func (u *Config) Stop() error { return nil } @@ -42,9 +43,11 @@ func New(instance instance) (*Config, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Config") module = &Config{ - instance: instance, + mgr: m, + instance: instance, + EventConfigChange: mgr.NewEventMgr[struct{}](ChangeEvent, m), } return module, nil } diff --git a/base/database/dbmodule/db.go b/base/database/dbmodule/db.go index 77af7958a..c29207167 100644 --- a/base/database/dbmodule/db.go +++ b/base/database/dbmodule/db.go @@ -15,12 +15,15 @@ type DBModule struct { instance instance } -func (dbm *DBModule) Start(m *mgr.Manager) error { - module.mgr = m +func (dbm *DBModule) Manager() *mgr.Manager { + return dbm.mgr +} + +func (dbm *DBModule) Start() error { return start() } -func (dbm *DBModule) Stop(m *mgr.Manager) error { +func (dbm *DBModule) Stop() error { return stop() } @@ -69,8 +72,9 @@ func New(instance instance) (*DBModule, error) { if err := prep(); err != nil { return nil, err } - + m := mgr.New("DBModule") module = &DBModule{ + mgr: m, instance: instance, } diff --git a/base/metrics/module.go b/base/metrics/module.go index a950c56e1..d312de3f4 100644 --- a/base/metrics/module.go +++ b/base/metrics/module.go @@ -17,16 +17,18 @@ type Metrics struct { metricTicker *mgr.SleepyTicker } -func (met *Metrics) Start(m *mgr.Manager) error { - met.mgr = m +func (met *Metrics) Manager() *mgr.Manager { + return met.mgr +} +func (met *Metrics) Start() error { if err := prepConfig(); err != nil { return err } return start() } -func (met *Metrics) Stop(m *mgr.Manager) error { +func (met *Metrics) Stop() error { return stop() } @@ -195,8 +197,9 @@ func New(instance instance) (*Metrics, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Metrics") module = &Metrics{ + mgr: m, instance: instance, } diff --git a/base/notifications/module.go b/base/notifications/module.go index 455376944..a3b7d835c 100644 --- a/base/notifications/module.go +++ b/base/notifications/module.go @@ -16,8 +16,11 @@ type Notifications struct { States *mgr.StateMgr } -func (n *Notifications) Start(m *mgr.Manager) error { - n.mgr = m +func (n *Notifications) Manager() *mgr.Manager { + return n.mgr +} + +func (n *Notifications) Start() error { n.States = mgr.NewStateMgr(n.mgr) if err := prep(); err != nil { @@ -27,7 +30,7 @@ func (n *Notifications) Start(m *mgr.Manager) error { return start() } -func (n *Notifications) Stop(m *mgr.Manager) error { +func (n *Notifications) Stop() error { return nil } @@ -92,9 +95,12 @@ func New(instance instance) (*Notifications, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Notifications") module = &Notifications{ + mgr: m, instance: instance, + + States: mgr.NewStateMgr(m), } return module, nil diff --git a/base/rng/rng.go b/base/rng/rng.go index 8c73ebe8c..01ad5dd1a 100644 --- a/base/rng/rng.go +++ b/base/rng/rng.go @@ -41,8 +41,11 @@ func newCipher(key []byte) 
(cipher.Block, error) { } } -func (r *Rng) Start(m *mgr.Manager) error { - r.mgr = m +func (r *Rng) Manager() *mgr.Manager { + return r.mgr +} + +func (r *Rng) Start() error { rngLock.Lock() defer rngLock.Unlock() @@ -52,7 +55,7 @@ func (r *Rng) Start(m *mgr.Manager) error { } // add another (async) OS rng seed - m.Go("initial rng feed", func(_ *mgr.WorkerCtx) error { + r.mgr.Go("initial rng feed", func(_ *mgr.WorkerCtx) error { // get entropy from OS osEntropy := make([]byte, minFeedEntropy/8) _, err := rand.Read(osEntropy) @@ -70,18 +73,18 @@ func (r *Rng) Start(m *mgr.Manager) error { rngReady = true // random source: OS - m.Go("os rng feeder", osFeeder) + r.mgr.Go("os rng feeder", osFeeder) // random source: goroutine ticks - m.Go("tick rng feeder", tickFeeder) + r.mgr.Go("tick rng feeder", tickFeeder) // full feeder - m.Go("full feeder", fullFeeder) + r.mgr.Go("full feeder", fullFeeder) return nil } -func (r *Rng) Stop(m *mgr.Manager) error { +func (r *Rng) Stop() error { return nil } @@ -94,8 +97,9 @@ func New(instance instance) (*Rng, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Rng") module = &Rng{ + mgr: m, instance: instance, } diff --git a/base/runtime/module.go b/base/runtime/module.go index 6ef8dec87..ac46d02aa 100644 --- a/base/runtime/module.go +++ b/base/runtime/module.go @@ -14,10 +14,15 @@ import ( var DefaultRegistry = NewRegistry() type Runtime struct { + mgr *mgr.Manager instance instance } -func (r *Runtime) Start(m *mgr.Manager) error { +func (r *Runtime) Manager() *mgr.Manager { + return r.mgr +} + +func (r *Runtime) Start() error { _, err := database.Register(&database.Database{ Name: "runtime", Description: "Runtime database", @@ -39,7 +44,7 @@ func (r *Runtime) Start(m *mgr.Manager) error { return nil } -func (r *Runtime) Stop(m *mgr.Manager) error { +func (r *Runtime) Stop() error { return nil } @@ -59,7 +64,11 @@ func New(instance instance) (*Runtime, error) { return nil, errors.New("only one instance allowed") } - module = &Runtime{instance: instance} + m := mgr.New("Runtime") + module = &Runtime{ + mgr: m, + instance: instance, + } return module, nil } diff --git a/service/broadcasts/module.go b/service/broadcasts/module.go index 1f6d3882b..bb162afdd 100644 --- a/service/broadcasts/module.go +++ b/service/broadcasts/module.go @@ -15,15 +15,18 @@ type Broadcasts struct { instance instance } -func (b *Broadcasts) Start(m *mgr.Manager) error { - b.mgr = m +func (b *Broadcasts) Manager() *mgr.Manager { + return b.mgr +} + +func (b *Broadcasts) Start() error { if err := prep(); err != nil { return err } return start() } -func (b *Broadcasts) Stop(m *mgr.Manager) error { +func (b *Broadcasts) Stop() error { return nil } @@ -71,8 +74,9 @@ func New(instance instance) (*Broadcasts, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Broadcasts") module = &Broadcasts{ + mgr: m, instance: instance, } return module, nil diff --git a/service/compat/module.go b/service/compat/module.go index 1f25d5271..9cf6791bd 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -20,9 +20,12 @@ type Compat struct { selfcheckWorkerMgr *mgr.WorkerMgr } +func (u *Compat) Manager() *mgr.Manager { + return u.mgr +} + // Start starts the module. 
-func (u *Compat) Start(m *mgr.Manager) error { - u.mgr = m +func (u *Compat) Start() error { if err := prep(); err != nil { return err } @@ -30,7 +33,7 @@ func (u *Compat) Start(m *mgr.Manager) error { } // Stop stops the module. -func (u *Compat) Stop(_ *mgr.Manager) error { +func (u *Compat) Stop() error { return stop() } @@ -153,8 +156,9 @@ func New(instance instance) (*Compat, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Compat") module = &Compat{ + mgr: m, instance: instance, } return module, nil diff --git a/service/core/base/module.go b/service/core/base/module.go index 165c3823c..57df37322 100644 --- a/service/core/base/module.go +++ b/service/core/base/module.go @@ -15,8 +15,11 @@ type Base struct { instance instance } -func (b *Base) Start(m *mgr.Manager) error { - b.mgr = m +func (b *Base) Manager() *mgr.Manager { + return b.mgr +} + +func (b *Base) Start() error { startProfiling() if err := registerDatabases(); err != nil { @@ -28,7 +31,7 @@ func (b *Base) Start(m *mgr.Manager) error { return nil } -func (b *Base) Stop(m *mgr.Manager) error { +func (b *Base) Stop() error { return nil } @@ -42,8 +45,9 @@ func New(instance instance) (*Base, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Base") module = &Base{ + mgr: m, instance: instance, } return module, nil diff --git a/service/core/core.go b/service/core/core.go index c1554e4bf..d6812d6b1 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -25,16 +25,18 @@ const ( ) type Core struct { + m *mgr.Manager instance instance EventShutdown *mgr.EventMgr[struct{}] EventRestart *mgr.EventMgr[struct{}] } -func (c *Core) Start(m *mgr.Manager) error { - c.EventShutdown = mgr.NewEventMgr[struct{}]("shutdown", m) - c.EventRestart = mgr.NewEventMgr[struct{}]("restart", m) +func (c *Core) Manager() *mgr.Manager { + return c.m +} +func (c *Core) Start() error { if err := prep(); err != nil { return err } @@ -42,7 +44,7 @@ func (c *Core) Start(m *mgr.Manager) error { return start() } -func (c *Core) Stop(m *mgr.Manager) error { +func (c *Core) Stop() error { return nil } @@ -123,8 +125,13 @@ func New(instance instance) (*Core, error) { return nil, errors.New("only one instance allowed") } + m := mgr.New("Core") module = &Core{ + m: m, instance: instance, + + EventShutdown: mgr.NewEventMgr[struct{}]("shutdown", m), + EventRestart: mgr.NewEventMgr[struct{}]("restart", m), } return module, nil } diff --git a/service/firewall/interception/module.go b/service/firewall/interception/module.go index 76964f1af..da0a727fd 100644 --- a/service/firewall/interception/module.go +++ b/service/firewall/interception/module.go @@ -15,12 +15,15 @@ type Interception struct { instance instance } -func (i *Interception) Start(m *mgr.Manager) error { - i.mgr = m +func (i *Interception) Manager() *mgr.Manager { + return i.mgr +} + +func (i *Interception) Start() error { return start() } -func (i *Interception) Stop(m *mgr.Manager) error { +func (i *Interception) Stop() error { return stop() } @@ -84,8 +87,9 @@ func New(instance instance) (*Interception, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Interception") module = &Interception{ + mgr: m, instance: instance, } return module, nil diff --git a/service/firewall/module.go b/service/firewall/module.go index 99af76d88..c3cc59a45 100644 --- a/service/firewall/module.go 
+++ b/service/firewall/module.go @@ -1,6 +1,7 @@ package firewall import ( + "errors" "flag" "fmt" "path/filepath" @@ -42,9 +43,11 @@ func init() { flag.Var(&allowedClients, "allowed-clients", "A list of binaries that are allowed to connect to the Portmaster API") } -func (f *Firewall) Start(mgr *mgr.Manager) error { - f.mgr = mgr +func (f *Firewall) Manager() *mgr.Manager { + return f.mgr +} +func (f *Firewall) Start() error { if err := prep(); err != nil { log.Errorf("Failed to prepare firewall module %q", err) return err @@ -53,7 +56,7 @@ func (f *Firewall) Start(mgr *mgr.Manager) error { return start() } -func (f *Firewall) Stop(mgr *mgr.Manager) error { +func (f *Firewall) Stop() error { return stop() } @@ -132,7 +135,13 @@ var ( ) func New(instance instance) (*Firewall, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + m := mgr.New("Firewall") module = &Firewall{ + mgr: m, instance: instance, } diff --git a/service/intel/customlists/module.go b/service/intel/customlists/module.go index 7401ed810..45b0eb2fe 100644 --- a/service/intel/customlists/module.go +++ b/service/intel/customlists/module.go @@ -26,19 +26,18 @@ type CustomList struct { States *mgr.StateMgr } -func (cl *CustomList) Start(m *mgr.Manager) error { - cl.mgr = m - cl.States = mgr.NewStateMgr(m) - - cl.updateFilterListWorkerMgr = m.NewWorkerMgr("update custom filter list", checkAndUpdateFilterList, nil) +func (cl *CustomList) Manager() *mgr.Manager { + return cl.mgr +} +func (cl *CustomList) Start() error { if err := prep(); err != nil { return err } return start() } -func (cl *CustomList) Stop(m *mgr.Manager) error { +func (cl *CustomList) Stop() error { return nil } @@ -219,9 +218,13 @@ func New(instance instance) (*CustomList, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("CustomList") module = &CustomList{ + mgr: m, instance: instance, + + States: mgr.NewStateMgr(m), + updateFilterListWorkerMgr: m.NewWorkerMgr("update custom filter list", checkAndUpdateFilterList, nil), } return module, nil } diff --git a/service/intel/filterlists/module.go b/service/intel/filterlists/module.go index 376402f90..c499ef209 100644 --- a/service/intel/filterlists/module.go +++ b/service/intel/filterlists/module.go @@ -26,17 +26,18 @@ type FilterLists struct { States *mgr.StateMgr } -func (fl *FilterLists) Start(m *mgr.Manager) error { - fl.mgr = m - fl.States = mgr.NewStateMgr(m) +func (fl *FilterLists) Manager() *mgr.Manager { + return fl.mgr +} +func (fl *FilterLists) Start() error { if err := prep(); err != nil { return err } return start() } -func (fl *FilterLists) Stop(m *mgr.Manager) error { +func (fl *FilterLists) Stop() error { return stop() } @@ -127,9 +128,12 @@ func New(instance instance) (*FilterLists, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("FilterLists") module = &FilterLists{ + mgr: m, instance: instance, + + States: mgr.NewStateMgr(m), } return module, nil } diff --git a/service/intel/geoip/module.go b/service/intel/geoip/module.go index 326352ecc..e8527e4ee 100644 --- a/service/intel/geoip/module.go +++ b/service/intel/geoip/module.go @@ -14,8 +14,11 @@ type GeoIP struct { instance instance } -func (g *GeoIP) Start(m *mgr.Manager) error { - g.mgr = m +func (g *GeoIP) Manager() *mgr.Manager { + return g.mgr +} + +func (g *GeoIP) Start() error { if err := api.RegisterEndpoint(api.Endpoint{ 
Path: "intel/geoip/countries", Read: api.PermitUser, @@ -38,7 +41,7 @@ func (g *GeoIP) Start(m *mgr.Manager) error { return nil } -func (g *GeoIP) Stop(m *mgr.Manager) error { +func (g *GeoIP) Stop() error { return nil } @@ -53,7 +56,9 @@ func New(instance instance) (*GeoIP, error) { return nil, errors.New("only one instance allowed") } + m := mgr.New("geoip") module = &GeoIP{ + mgr: m, instance: instance, } return module, nil diff --git a/service/mgr/module.go b/service/mgr/module.go index 8f2867c37..93de1efa1 100644 --- a/service/mgr/module.go +++ b/service/mgr/module.go @@ -105,7 +105,7 @@ func (g *Group) Start() error { startTime := time.Now() err := m.mgr.Do(m.mgr.name+" Start", func(_ *WorkerCtx) error { - return m.module.Start(m.mgr) + return m.module.Start() }) if err != nil { if !g.stopFrom(i) { @@ -143,7 +143,7 @@ func (g *Group) stopFrom(index int) (ok bool) { m := g.modules[i] err := m.mgr.Do(m.mgr.name+" Stop", func(_ *WorkerCtx) error { - return m.module.Stop(m.mgr) + return m.module.Stop() }) if err != nil { m.mgr.Error("failed to stop", "err", err) diff --git a/service/nameserver/module.go b/service/nameserver/module.go index 1107c85d3..d062c8fad 100644 --- a/service/nameserver/module.go +++ b/service/nameserver/module.go @@ -26,16 +26,18 @@ type NameServer struct { States *mgr.StateMgr } -func (ns *NameServer) Start(m *mgr.Manager) error { - ns.mgr = m - ns.States = mgr.NewStateMgr(m) +func (ns *NameServer) Manager() *mgr.Manager { + return ns.mgr +} + +func (ns *NameServer) Start() error { if err := prep(); err != nil { return err } return start() } -func (ns *NameServer) Stop(m *mgr.Manager) error { +func (ns *NameServer) Stop() error { return stop() } @@ -316,9 +318,12 @@ func New(instance instance) (*NameServer, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("NameServer") module = &NameServer{ + mgr: m, instance: instance, + + States: mgr.NewStateMgr(m), } return module, nil } diff --git a/service/netenv/main.go b/service/netenv/main.go index f6dcef9d7..44848a5fd 100644 --- a/service/netenv/main.go +++ b/service/netenv/main.go @@ -18,26 +18,28 @@ const ( ) type NetEnv struct { + m *mgr.Manager instance instance EventNetworkChange *mgr.EventMgr[struct{}] EventOnlineStatusChange *mgr.EventMgr[OnlineStatus] } -func (ne *NetEnv) Start(m *mgr.Manager) error { - ne.EventNetworkChange = mgr.NewEventMgr[struct{}]("network change", m) - ne.EventOnlineStatusChange = mgr.NewEventMgr[OnlineStatus]("online status change", m) +func (ne *NetEnv) Manager() *mgr.Manager { + return ne.m +} +func (ne *NetEnv) Start() error { if err := prep(); err != nil { return err } - m.Go( + ne.m.Go( "monitor network changes", monitorNetworkChanges, ) - m.Go( + ne.m.Go( "monitor online status", monitorOnlineStatus, ) @@ -45,7 +47,7 @@ func (ne *NetEnv) Start(m *mgr.Manager) error { return nil } -func (ne *NetEnv) Stop(m *mgr.Manager) error { +func (ne *NetEnv) Stop() error { return nil } @@ -92,9 +94,13 @@ func New(instance instance) (*NetEnv, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("NetEnv") module = &NetEnv{ + m: m, instance: instance, + + EventNetworkChange: mgr.NewEventMgr[struct{}]("network change", m), + EventOnlineStatusChange: mgr.NewEventMgr[OnlineStatus]("online status change", m), } return module, nil } diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go index c0d6a0a1a..518d260a7 100644 --- 
a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -199,8 +199,11 @@ func (nq *NetQuery) prepare() error { return nil } -func (nq *NetQuery) Start(m *mgr.Manager) error { - nq.mgr = m +func (nq *NetQuery) Manager() *mgr.Manager { + return nq.mgr +} + +func (nq *NetQuery) Start() error { if err := nq.prepare(); err != nil { return fmt.Errorf("failed to prepare netquery module: %w", err) } @@ -284,7 +287,7 @@ func (nq *NetQuery) Start(m *mgr.Manager) error { return nil } -func (nq *NetQuery) Stop(m *mgr.Manager) error { +func (nq *NetQuery) Stop() error { // we don't use m.Module.Ctx here because it is already cancelled when stop is called. // just give the clean up 1 minute to happen and abort otherwise. ctx, cancel := context.WithTimeout(context.Background(), time.Minute) @@ -318,8 +321,9 @@ func NewModule(instance instance) (*NetQuery, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("NetQuery") module = &NetQuery{ + mgr: m, instance: instance, } return module, nil diff --git a/service/network/module.go b/service/network/module.go index bff26e561..396e81533 100644 --- a/service/network/module.go +++ b/service/network/module.go @@ -30,17 +30,18 @@ type Network struct { EventConnectionReattributed *mgr.EventMgr[string] } -func (n *Network) Start(m *mgr.Manager) error { - n.mgr = m - n.EventConnectionReattributed = mgr.NewEventMgr[string](ConnectionReattributedEvent, m) +func (n *Network) Manager() *mgr.Manager { + return n.mgr +} +func (n *Network) Start() error { if err := prep(); err != nil { return err } return start() } -func (n *Network) Stop(mgr *mgr.Manager) error { +func (n *Network) Stop() error { return nil } @@ -174,9 +175,11 @@ func New(instance instance) (*Network, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Network") module = &Network{ - instance: instance, + mgr: m, + instance: instance, + EventConnectionReattributed: mgr.NewEventMgr[string](ConnectionReattributedEvent, m), } return module, nil } diff --git a/service/process/module.go b/service/process/module.go index ddc2808e1..eb6693931 100644 --- a/service/process/module.go +++ b/service/process/module.go @@ -10,10 +10,15 @@ import ( ) type ProcessModule struct { + mgr *mgr.Manager instance instance } -func (pm *ProcessModule) Start(m *mgr.Manager) error { +func (pm *ProcessModule) Manager() *mgr.Manager { + return pm.mgr +} + +func (pm *ProcessModule) Start() error { if err := prep(); err != nil { return err } @@ -21,7 +26,7 @@ func (pm *ProcessModule) Start(m *mgr.Manager) error { return start() } -func (pm *ProcessModule) Stop(m *mgr.Manager) error { +func (pm *ProcessModule) Stop() error { return nil } @@ -58,8 +63,9 @@ func New(instance instance) (*ProcessModule, error) { if err := prep(); err != nil { return nil, err } - + m := mgr.New("ProcessModule") module = &ProcessModule{ + mgr: m, instance: instance, } return module, nil diff --git a/service/profile/module.go b/service/profile/module.go index 0233351ab..81e506a51 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -40,15 +40,11 @@ type ProfileModule struct { States *mgr.StateMgr } -func (pm *ProfileModule) Start(m *mgr.Manager) error { - pm.mgr = m - - pm.EventConfigChange = mgr.NewEventMgr[string](ConfigChangeEvent, m) - pm.EventDelete = mgr.NewEventMgr[string](DeletedEvent, m) - pm.EventMigrated = mgr.NewEventMgr[[]string](MigratedEvent, m) - - pm.States = 
mgr.NewStateMgr(m) +func (pm *ProfileModule) Manager() *mgr.Manager { + return pm.mgr +} +func (pm *ProfileModule) Start() error { if err := prep(); err != nil { return err } @@ -56,7 +52,7 @@ func (pm *ProfileModule) Start(m *mgr.Manager) error { return start() } -func (pm *ProfileModule) Stop(m *mgr.Manager) error { +func (pm *ProfileModule) Stop() error { return stop() } @@ -143,9 +139,16 @@ func NewModule(instance instance) (*ProfileModule, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Profile") module = &ProfileModule{ + mgr: m, instance: instance, + + EventConfigChange: mgr.NewEventMgr[string](ConfigChangeEvent, m), + EventDelete: mgr.NewEventMgr[string](DeletedEvent, m), + EventMigrated: mgr.NewEventMgr[[]string](MigratedEvent, m), + + States: mgr.NewStateMgr(m), } return module, nil diff --git a/service/resolver/main.go b/service/resolver/main.go index a15486111..69e4c055e 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -31,17 +31,18 @@ type ResolverModule struct { States *mgr.StateMgr } -func (rm *ResolverModule) Start(m *mgr.Manager) error { - rm.mgr = m - rm.States = mgr.NewStateMgr(m) +func (rm *ResolverModule) Manager() *mgr.Manager { + return rm.mgr +} +func (rm *ResolverModule) Start() error { if err := prep(); err != nil { return err } return start() } -func (rm *ResolverModule) Stop(m *mgr.Manager) error { +func (rm *ResolverModule) Stop() error { return nil } @@ -257,9 +258,12 @@ func New(instance instance) (*ResolverModule, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Resolver") module = &ResolverModule{ + mgr: m, instance: instance, + + States: mgr.NewStateMgr(m), } return module, nil } diff --git a/service/status/module.go b/service/status/module.go index bd414ffb6..946bd41cd 100644 --- a/service/status/module.go +++ b/service/status/module.go @@ -11,10 +11,15 @@ import ( ) type Status struct { + mgr *mgr.Manager instance instance } -func (s *Status) Start(m *mgr.Manager) error { +func (s *Status) Manager() *mgr.Manager { + return s.mgr +} + +func (s *Status) Start() error { if err := setupRuntimeProvider(); err != nil { return err } @@ -29,7 +34,7 @@ func (s *Status) Start(m *mgr.Manager) error { return nil } -func (s *Status) Stop(m *mgr.Manager) error { +func (s *Status) Stop() error { return nil } @@ -52,8 +57,9 @@ func New(instance instance) (*Status, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Status") module = &Status{ + mgr: m, instance: instance, } diff --git a/service/sync/module.go b/service/sync/module.go index 3d7a8a8ed..c6a8c9d1f 100644 --- a/service/sync/module.go +++ b/service/sync/module.go @@ -9,14 +9,19 @@ import ( ) type Sync struct { + mgr *mgr.Manager instance instance } -func (s *Sync) Start(m *mgr.Manager) error { +func (s *Sync) Manager() *mgr.Manager { + return s.mgr +} + +func (s *Sync) Start() error { return prep() } -func (s *Sync) Stop(m *mgr.Manager) error { +func (s *Sync) Stop() error { return nil } @@ -48,8 +53,9 @@ func New(instance instance) (*Sync, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Sync") module = &Sync{ + mgr: m, instance: instance, } return module, nil diff --git a/service/ui/module.go b/service/ui/module.go index 9c74c4d17..09639ba22 100644 --- a/service/ui/module.go +++ 
b/service/ui/module.go @@ -42,10 +42,12 @@ type UI struct { instance instance } -// Start starts the module. -func (ui *UI) Start(m *mgr.Manager) error { - ui.mgr = m +func (ui *UI) Manager() *mgr.Manager { + return ui.mgr +} +// Start starts the module. +func (ui *UI) Start() error { if err := prep(); err != nil { return err } @@ -54,7 +56,7 @@ func (ui *UI) Start(m *mgr.Manager) error { } // Stop stops the module. -func (ui *UI) Stop(_ *mgr.Manager) error { +func (ui *UI) Stop() error { return nil } @@ -62,12 +64,15 @@ var shimLoaded atomic.Bool // New returns a new UI module. func New(instance instance) (*UI, error) { - if shimLoaded.CompareAndSwap(false, true) { - return &UI{ - instance: instance, - }, nil + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + m := mgr.New("UI") + module := &UI{ + mgr: m, + instance: instance, } - return nil, errors.New("only one instance allowed") + return module, nil } type instance interface { diff --git a/service/updates/config.go b/service/updates/config.go index d51f14187..f765fd4c0 100644 --- a/service/updates/config.go +++ b/service/updates/config.go @@ -164,7 +164,7 @@ func updateRegistryConfig(_ *mgr.WorkerCtx, _ struct{}) (cancel bool, err error) module.EventVersionsUpdated.Submit(struct{}{}) if softwareUpdatesCurrentlyEnabled || intelUpdatesCurrentlyEnabled { - module.States.Clear() + module.states.Clear() if err := TriggerUpdate(true, false); err != nil { log.Warningf("updates: failed to trigger update: %s", err) } diff --git a/service/updates/main.go b/service/updates/main.go index fbdaa9dd7..09e25f965 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -106,7 +106,7 @@ func prep() error { func start() error { initConfig() - module.restartWorkerMgr = module.mgr.Repeat("automatic restart", 10*time.Minute, automaticRestart) + module.restartWorkerMgr.Repeat(10 * time.Minute) module.instance.Config().EventConfigChange.AddCallback("update registry config", updateRegistryConfig) // create registry @@ -164,7 +164,7 @@ func start() error { log.Warningf("updates: %s", warning) } - err = registry.LoadIndexes(module.mgr.Ctx()) + err = registry.LoadIndexes(module.m.Ctx()) if err != nil { log.Warningf("updates: failed to load indexes: %s", err) } @@ -184,8 +184,6 @@ func start() error { } // start updater task - module.updateWorkerMgr = module.mgr.NewWorkerMgr("updater", checkForUpdates, nil) - if !disableTaskSchedule { _ = module.updateWorkerMgr.Repeat(30 * time.Minute) } diff --git a/service/updates/module.go b/service/updates/module.go index 87ade51a4..c20ef051f 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -35,12 +35,13 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { return nil, errors.New("only one instance allowed") } - m := mgr.New("updates") + m := mgr.New("Updates") module = &Updates{ - m: m, - states: m.NewStateMgr(), - updateWorkerMgr: m.NewWorkerMgr("updater", checkForUpdates, nil), //FIXME - restartWorkerMgr: m.NewWorkerMgr("updater", checkForUpdates, nil), //FIXME + m: m, + states: m.NewStateMgr(), + + updateWorkerMgr: m.NewWorkerMgr("updater", checkForUpdates, nil), + restartWorkerMgr: m.NewWorkerMgr("automatic restart", automaticRestart, nil), EventResourcesUpdated: mgr.NewEventMgr[struct{}](ResourceUpdateEvent, m), EventVersionsUpdated: mgr.NewEventMgr[struct{}](VersionUpdateEvent, m), instance: instance, @@ -50,13 +51,18 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) 
{ return module, nil } +// State returns the state manager. +func (u *Updates) State() *mgr.StateMgr { + return u.states +} + // Manager returns the module manager. func (u *Updates) Manager() *mgr.Manager { return u.m } // Start starts the module. -func (u *Updates) Start(m *mgr.Manager) error { +func (u *Updates) Start() error { if err := prep(); err != nil { return err } @@ -65,7 +71,7 @@ func (u *Updates) Start(m *mgr.Manager) error { } // Stop stops the module. -func (u *Updates) Stop(_ *mgr.Manager) error { +func (u *Updates) Stop() error { return stop() } diff --git a/service/updates/notify.go b/service/updates/notify.go index b20bd92fe..0cd97bfde 100644 --- a/service/updates/notify.go +++ b/service/updates/notify.go @@ -23,7 +23,7 @@ var updateFailedCnt = new(atomic.Int32) func notifyUpdateSuccess(force bool) { updateFailedCnt.Store(0) - module.States.Clear() + module.states.Clear() updateState := registry.GetState().Updates flavor := updateSuccess diff --git a/spn/access/module.go b/spn/access/module.go index 95ab8bcce..a7ad74919 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -25,11 +25,11 @@ type Access struct { EventAccountUpdate *mgr.EventMgr[struct{}] } -func (a *Access) Start(m *mgr.Manager) error { - a.mgr = m - a.EventAccountUpdate = mgr.NewEventMgr[struct{}](AccountUpdateEvent, m) - a.updateAccountWorkerMgr = m.NewWorkerMgr("update account", UpdateAccount, nil) +func (a *Access) Manager() *mgr.Manager { + return a.mgr +} +func (a *Access) Start() error { if err := prep(); err != nil { return err } @@ -37,7 +37,7 @@ func (a *Access) Start(m *mgr.Manager) error { return start() } -func (a *Access) Stop(m *mgr.Manager) error { +func (a *Access) Stop() error { return stop() } @@ -216,8 +216,13 @@ func New(instance instance) (*Access, error) { return nil, errors.New("only one instance allowed") } + m := mgr.New("Access") module = &Access{ + mgr: m, instance: instance, + + EventAccountUpdate: mgr.NewEventMgr[struct{}](AccountUpdateEvent, m), + updateAccountWorkerMgr: m.NewWorkerMgr("update account", UpdateAccount, nil), } return module, nil } diff --git a/spn/cabin/module.go b/spn/cabin/module.go index 8e35b6128..8d6a48974 100644 --- a/spn/cabin/module.go +++ b/spn/cabin/module.go @@ -9,14 +9,19 @@ import ( ) type Cabin struct { + m *mgr.Manager instance instance } -func (c *Cabin) Start(m *mgr.Manager) error { +func (c *Cabin) Manager() *mgr.Manager { + return c.m +} + +func (c *Cabin) Start() error { return prep() } -func (c *Cabin) Stop(m *mgr.Manager) error { +func (c *Cabin) Stop() error { return nil } @@ -49,7 +54,9 @@ func New(instance instance) (*Cabin, error) { return nil, err } + m := mgr.New("Cabin") module = &Cabin{ + m: m, instance: instance, } return module, nil diff --git a/spn/captain/module.go b/spn/captain/module.go index ca2b84875..c526b5ea2 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -42,12 +42,11 @@ type Captain struct { EventSPNConnected *mgr.EventMgr[struct{}] } -func (c *Captain) Start(m *mgr.Manager) error { - c.mgr = m - c.States = mgr.NewStateMgr(m) - c.EventSPNConnected = mgr.NewEventMgr[struct{}](SPNConnectedEvent, m) - c.maintainPublicStatus = m.NewWorkerMgr("maintain public status", maintainPublicStatus, nil) +func (c *Captain) Manager() *mgr.Manager { + return c.mgr +} +func (c *Captain) Start() error { if err := prep(); err != nil { return err } @@ -55,7 +54,7 @@ func (c *Captain) Start(m *mgr.Manager) error { return start() } -func (c *Captain) Stop(m *mgr.Manager) error { +func (c *Captain) Stop() error { 
return stop() } @@ -250,10 +249,15 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Captain, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Captain") module = &Captain{ + mgr: m, instance: instance, shutdownFunc: shutdownFunc, + + States: mgr.NewStateMgr(m), + EventSPNConnected: mgr.NewEventMgr[struct{}](SPNConnectedEvent, m), + maintainPublicStatus: m.NewWorkerMgr("maintain public status", maintainPublicStatus, nil), } return module, nil } diff --git a/spn/crew/module.go b/spn/crew/module.go index 54ab7051b..cc384c58e 100644 --- a/spn/crew/module.go +++ b/spn/crew/module.go @@ -14,12 +14,15 @@ type Crew struct { instance instance } -func (c *Crew) Start(m *mgr.Manager) error { - c.mgr = m +func (c *Crew) Manager() *mgr.Manager { + return c.mgr +} + +func (c *Crew) Start() error { return start() } -func (c *Crew) Stop(m *mgr.Manager) error { +func (c *Crew) Stop() error { return stop() } @@ -61,8 +64,9 @@ func New(instance instance) (*Crew, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Crew") module = &Crew{ + mgr: m, instance: instance, } return module, nil diff --git a/spn/docks/module.go b/spn/docks/module.go index 2c2f17b97..4ee6242f4 100644 --- a/spn/docks/module.go +++ b/spn/docks/module.go @@ -17,12 +17,15 @@ type Docks struct { instance instance } -func (d *Docks) Start(m *mgr.Manager) error { - d.mgr = m +func (d *Docks) Manager() *mgr.Manager { + return d.mgr +} + +func (d *Docks) Start() error { return start() } -func (d *Docks) Stop(m *mgr.Manager) error { +func (d *Docks) Stop() error { return stopAllCranes() } @@ -135,8 +138,9 @@ func New(instance instance) (*Docks, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Docks") module = &Docks{ + mgr: m, instance: instance, } return module, nil diff --git a/spn/navigator/module.go b/spn/navigator/module.go index b763b893d..4392b98dc 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -40,8 +40,11 @@ type Navigator struct { instance instance } -func (n *Navigator) Start(m *mgr.Manager) error { - n.mgr = m +func (n *Navigator) Manager() *mgr.Manager { + return n.mgr +} + +func (n *Navigator) Start() error { if err := prep(); err != nil { return err } @@ -49,7 +52,7 @@ func (n *Navigator) Start(m *mgr.Manager) error { return start() } -func (n *Navigator) Stop(m *mgr.Manager) error { +func (n *Navigator) Stop() error { return stop() } @@ -145,8 +148,9 @@ func New(instance instance) (*Navigator, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Navigator") module = &Navigator{ + mgr: m, instance: instance, } return module, nil diff --git a/spn/patrol/module.go b/spn/patrol/module.go index 414c06165..3c3ce769b 100644 --- a/spn/patrol/module.go +++ b/spn/patrol/module.go @@ -19,17 +19,18 @@ type Patrol struct { EventChangeSignal *mgr.EventMgr[struct{}] } -func (p *Patrol) Start(m *mgr.Manager) error { - p.mgr = m - p.EventChangeSignal = mgr.NewEventMgr[struct{}](ChangeSignalEventName, m) +func (p *Patrol) Manager() *mgr.Manager { + return p.mgr +} +func (p *Patrol) Start() error { if conf.PublicHub() { - m.Repeat("connectivity test", 5*time.Minute, connectivityCheckTask) + p.mgr.Repeat("connectivity test", 5*time.Minute, connectivityCheckTask) } return nil } -func (p *Patrol) Stop(m *mgr.Manager) 
error { +func (p *Patrol) Stop() error { return nil } @@ -43,9 +44,12 @@ func New(instance instance) (*Patrol, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Patrol") module = &Patrol{ + mgr: m, instance: instance, + + EventChangeSignal: mgr.NewEventMgr[struct{}](ChangeSignalEventName, m), } return module, nil } diff --git a/spn/ships/module.go b/spn/ships/module.go index 1bdd8c95b..57b10657b 100644 --- a/spn/ships/module.go +++ b/spn/ships/module.go @@ -13,8 +13,11 @@ type Ships struct { instance instance } -func (s *Ships) Start(m *mgr.Manager) error { - s.mgr = m +func (s *Ships) Manager() *mgr.Manager { + return s.mgr +} + +func (s *Ships) Start() error { if conf.PublicHub() { initPageInput() } @@ -22,7 +25,7 @@ func (s *Ships) Start(m *mgr.Manager) error { return nil } -func (s *Ships) Stop(m *mgr.Manager) error { +func (s *Ships) Stop() error { return nil } @@ -36,8 +39,9 @@ func New(instance instance) (*Ships, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("Ships") module = &Ships{ + mgr: m, instance: instance, } return module, nil diff --git a/spn/sluice/module.go b/spn/sluice/module.go index 9c2c4cbf5..99197dea5 100644 --- a/spn/sluice/module.go +++ b/spn/sluice/module.go @@ -15,12 +15,15 @@ type SluiceModule struct { instance instance } -func (s *SluiceModule) Start(m *mgr.Manager) error { - s.mgr = m +func (s *SluiceModule) Manager() *mgr.Manager { + return s.mgr +} + +func (s *SluiceModule) Start() error { return start() } -func (s *SluiceModule) Stop(_ *mgr.Manager) error { +func (s *SluiceModule) Stop() error { return stop() } @@ -66,8 +69,9 @@ func New(instance instance) (*SluiceModule, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("SluiceModule") module = &SluiceModule{ + mgr: m, instance: instance, } return module, nil diff --git a/spn/terminal/module.go b/spn/terminal/module.go index d032ccd03..46424a5ac 100644 --- a/spn/terminal/module.go +++ b/spn/terminal/module.go @@ -17,12 +17,15 @@ type TerminalModule struct { instance instance } -func (s *TerminalModule) Start(m *mgr.Manager) error { - s.mgr = m +func (s *TerminalModule) Manager() *mgr.Manager { + return s.mgr +} + +func (s *TerminalModule) Start() error { return start() } -func (s *TerminalModule) Stop(m *mgr.Manager) error { +func (s *TerminalModule) Stop() error { return nil } @@ -102,8 +105,9 @@ func New(instance instance) (*TerminalModule, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } - + m := mgr.New("TerminalModule") module = &TerminalModule{ + mgr: m, instance: instance, } return module, nil From e9ae5835ad12da1a6dfb837946f85111d7ec9d93 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Tue, 9 Jul 2024 16:39:31 +0300 Subject: [PATCH 23/56] [WIP] Fix setting for enabling and disabling the SPN module --- base/api/module.go | 6 +- base/config/module.go | 8 +- base/metrics/module.go | 6 +- service/firewall/module.go | 8 +- service/instance.go | 152 +++++++++++++++------------- service/intel/customlists/module.go | 12 +-- service/nameserver/module.go | 7 +- service/process/module.go | 8 +- service/profile/module.go | 8 +- service/updates/main.go | 4 - service/updates/module.go | 4 + spn/access/module.go | 14 ++- spn/cabin/module.go | 11 +- spn/captain/config.go | 18 +++- spn/captain/module.go | 10 +- 15 files 
changed, 159 insertions(+), 117 deletions(-) diff --git a/base/api/module.go b/base/api/module.go index d75d29f22..1f912ba42 100644 --- a/base/api/module.go +++ b/base/api/module.go @@ -22,9 +22,6 @@ func (api *API) Manager() *mgr.Manager { // Start starts the module. func (api *API) Start() error { - if err := prep(); err != nil { - return err - } if err := start(); err != nil { return err } @@ -54,6 +51,9 @@ func New(instance instance) (*API, error) { instance: instance, } + if err := prep(); err != nil { + return nil, err + } return module, nil } diff --git a/base/config/module.go b/base/config/module.go index a83660f69..e0d68fdba 100644 --- a/base/config/module.go +++ b/base/config/module.go @@ -22,9 +22,6 @@ func (u *Config) Manager() *mgr.Manager { // Start starts the module. func (u *Config) Start() error { - if err := prep(); err != nil { - return err - } return start() } @@ -49,6 +46,11 @@ func New(instance instance) (*Config, error) { instance: instance, EventConfigChange: mgr.NewEventMgr[struct{}](ChangeEvent, m), } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/base/metrics/module.go b/base/metrics/module.go index d312de3f4..d77a982f3 100644 --- a/base/metrics/module.go +++ b/base/metrics/module.go @@ -22,9 +22,6 @@ func (met *Metrics) Manager() *mgr.Manager { } func (met *Metrics) Start() error { - if err := prepConfig(); err != nil { - return err - } return start() } @@ -202,6 +199,9 @@ func New(instance instance) (*Metrics, error) { mgr: m, instance: instance, } + if err := prepConfig(); err != nil { + return nil, err + } return module, nil } diff --git a/service/firewall/module.go b/service/firewall/module.go index c3cc59a45..2be348fb3 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -103,10 +103,6 @@ func prep() error { return false, err }) - if err := registerConfig(); err != nil { - return err - } - return nil } @@ -149,6 +145,10 @@ func New(instance instance) (*Firewall, error) { return nil, err } + if err := registerConfig(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/instance.go b/service/instance.go index a896b960e..6291dd7c7 100644 --- a/service/instance.go +++ b/service/instance.go @@ -58,22 +58,10 @@ type Instance struct { rng *rng.Rng base *base.Base - core *core.Core - updates *updates.Updates - geoip *geoip.GeoIP - netenv *netenv.NetEnv - - access *access.Access - cabin *cabin.Cabin - navigator *navigator.Navigator - captain *captain.Captain - crew *crew.Crew - docks *docks.Docks - patrol *patrol.Patrol - ships *ships.Ships - sluice *sluice.SluiceModule - terminal *terminal.TerminalModule - + core *core.Core + updates *updates.Updates + geoip *geoip.GeoIP + netenv *netenv.NetEnv ui *ui.UI profile *profile.ProfileModule network *network.Network @@ -89,6 +77,20 @@ type Instance struct { process *process.ProcessModule resolver *resolver.ResolverModule sync *sync.Sync + + access *access.Access + + // SPN modules + SpnGroup *mgr.Group + cabin *cabin.Cabin + navigator *navigator.Navigator + captain *captain.Captain + crew *crew.Crew + docks *docks.Docks + patrol *patrol.Patrol + ships *ships.Ships + sluice *sluice.SluiceModule + terminal *terminal.TerminalModule } // New returns a new portmaster service instance. 
@@ -134,7 +136,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { return nil, fmt.Errorf("create base module: %w", err) } - // Global service modules + // Service modules instance.core, err = core.New(instance) if err != nil { return nil, fmt.Errorf("create core module: %w", err) @@ -151,50 +153,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create netenv module: %w", err) } - - // SPN modules - instance.access, err = access.New(instance) - if err != nil { - return nil, fmt.Errorf("create access module: %w", err) - } - instance.cabin, err = cabin.New(instance) - if err != nil { - return nil, fmt.Errorf("create cabin module: %w", err) - } - instance.navigator, err = navigator.New(instance) - if err != nil { - return nil, fmt.Errorf("create navigator module: %w", err) - } - instance.captain, err = captain.New(instance, svcCfg.ShutdownFunc) - if err != nil { - return nil, fmt.Errorf("create captain module: %w", err) - } - instance.crew, err = crew.New(instance) - if err != nil { - return nil, fmt.Errorf("create crew module: %w", err) - } - instance.docks, err = docks.New(instance) - if err != nil { - return nil, fmt.Errorf("create docks module: %w", err) - } - instance.patrol, err = patrol.New(instance) - if err != nil { - return nil, fmt.Errorf("create patrol module: %w", err) - } - instance.ships, err = ships.New(instance) - if err != nil { - return nil, fmt.Errorf("create ships module: %w", err) - } - instance.sluice, err = sluice.New(instance) - if err != nil { - return nil, fmt.Errorf("create sluice module: %w", err) - } - instance.terminal, err = terminal.New(instance) - if err != nil { - return nil, fmt.Errorf("create terminal module: %w", err) - } - - // Service modules instance.ui, err = ui.New(instance) if err != nil { return nil, fmt.Errorf("create ui module: %w", err) @@ -255,6 +213,48 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create sync module: %w", err) } + instance.access, err = access.New(instance) + if err != nil { + return nil, fmt.Errorf("create access module: %w", err) + } + + // SPN modules + instance.cabin, err = cabin.New(instance) + if err != nil { + return nil, fmt.Errorf("create cabin module: %w", err) + } + instance.navigator, err = navigator.New(instance) + if err != nil { + return nil, fmt.Errorf("create navigator module: %w", err) + } + instance.captain, err = captain.New(instance, svcCfg.ShutdownFunc) + if err != nil { + return nil, fmt.Errorf("create captain module: %w", err) + } + instance.crew, err = crew.New(instance) + if err != nil { + return nil, fmt.Errorf("create crew module: %w", err) + } + instance.docks, err = docks.New(instance) + if err != nil { + return nil, fmt.Errorf("create docks module: %w", err) + } + instance.patrol, err = patrol.New(instance) + if err != nil { + return nil, fmt.Errorf("create patrol module: %w", err) + } + instance.ships, err = ships.New(instance) + if err != nil { + return nil, fmt.Errorf("create ships module: %w", err) + } + instance.sluice, err = sluice.New(instance) + if err != nil { + return nil, fmt.Errorf("create sluice module: %w", err) + } + instance.terminal, err = terminal.New(instance) + if err != nil { + return nil, fmt.Errorf("create terminal module: %w", err) + } // Add all modules to instance group. 
instance.Group = mgr.NewGroup( @@ -272,17 +272,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.geoip, instance.netenv, - instance.access, - instance.cabin, - instance.navigator, - instance.captain, - instance.crew, - instance.docks, - instance.patrol, - instance.ships, - instance.sluice, - instance.terminal, - instance.ui, instance.profile, instance.network, @@ -298,6 +287,20 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.process, instance.resolver, instance.sync, + instance.access, + ) + + // SPN Group + instance.SpnGroup = mgr.NewGroup( + instance.cabin, + instance.navigator, + instance.captain, + instance.crew, + instance.docks, + instance.patrol, + instance.ships, + instance.sluice, + instance.terminal, ) // FIXME: call this before to trigger shutdown/restart event @@ -502,6 +505,11 @@ func (i *Instance) Core() *core.Core { return i.core } +// SPNGroup returns the group of all SPN modules. +func (i *Instance) SPNGroup() *mgr.Group { + return i.SpnGroup +} + // Events // SPN connected func (i *Instance) GetEventSPNConnected() *mgr.EventMgr[struct{}] { diff --git a/service/intel/customlists/module.go b/service/intel/customlists/module.go index 45b0eb2fe..615a5f82d 100644 --- a/service/intel/customlists/module.go +++ b/service/intel/customlists/module.go @@ -61,12 +61,6 @@ var ( func prep() error { initFilterLists() - // Register the config in the ui. - err := registerConfig() - if err != nil { - return err - } - // Register api endpoint for updating the filter list. if err := api.RegisterEndpoint(api.Endpoint{ Path: "customlists/update", @@ -226,6 +220,12 @@ func New(instance instance) (*CustomList, error) { States: mgr.NewStateMgr(m), updateFilterListWorkerMgr: m.NewWorkerMgr("update custom filter list", checkAndUpdateFilterList, nil), } + // Register the config in the ui. 
+ err := registerConfig() + if err != nil { + return nil, err + } + return module, nil } diff --git a/service/nameserver/module.go b/service/nameserver/module.go index d062c8fad..248b90abf 100644 --- a/service/nameserver/module.go +++ b/service/nameserver/module.go @@ -31,9 +31,6 @@ func (ns *NameServer) Manager() *mgr.Manager { } func (ns *NameServer) Start() error { - if err := prep(); err != nil { - return err - } return start() } @@ -325,6 +322,10 @@ func New(instance instance) (*NameServer, error) { States: mgr.NewStateMgr(m), } + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/process/module.go b/service/process/module.go index eb6693931..ccab72d55 100644 --- a/service/process/module.go +++ b/service/process/module.go @@ -19,10 +19,6 @@ func (pm *ProcessModule) Manager() *mgr.Manager { } func (pm *ProcessModule) Start() error { - if err := prep(); err != nil { - return err - } - return start() } @@ -68,6 +64,10 @@ func New(instance instance) (*ProcessModule, error) { mgr: m, instance: instance, } + + if err := prep(); err != nil { + return nil, err + } return module, nil } diff --git a/service/profile/module.go b/service/profile/module.go index 81e506a51..8f08008aa 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -45,10 +45,6 @@ func (pm *ProfileModule) Manager() *mgr.Manager { } func (pm *ProfileModule) Start() error { - if err := prep(); err != nil { - return err - } - return start() } @@ -151,6 +147,10 @@ func NewModule(instance instance) (*ProfileModule, error) { States: mgr.NewStateMgr(m), } + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/updates/main.go b/service/updates/main.go index 09e25f965..a2d0167b8 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -85,10 +85,6 @@ func init() { } func prep() error { - if err := registerConfig(); err != nil { - return err - } - // Check if update server URL supplied via flag is a valid URL. if updateServerFromFlag != "" { u, err := url.Parse(updateServerFromFlag) diff --git a/service/updates/module.go b/service/updates/module.go index c20ef051f..df8e4fdd1 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -48,6 +48,10 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { shutdownFunc: shutdownFunc, } + if err := registerConfig(); err != nil { + return nil, err + } + return module, nil } diff --git a/spn/access/module.go b/spn/access/module.go index a7ad74919..b62599c20 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -77,6 +77,15 @@ func prep() error { } func start() error { + module.instance.Config().EventConfigChange.AddCallback("spn enable check", func(wc *mgr.WorkerCtx, s struct{}) (bool, error) { + enabled := config.GetAsBool("spn/enable", false) + if enabled() { + return false, module.instance.SPNGroup().Start() + } else { + return false, module.instance.SPNGroup().Stop() + } + }) + // Initialize zones. 
if err := InitializeZones(); err != nil { return err @@ -227,4 +236,7 @@ func New(instance instance) (*Access, error) { return module, nil } -type instance interface{} +type instance interface { + Config() *config.Config + SPNGroup() *mgr.Group +} diff --git a/spn/cabin/module.go b/spn/cabin/module.go index 8d6a48974..93a5dd4f4 100644 --- a/spn/cabin/module.go +++ b/spn/cabin/module.go @@ -18,7 +18,7 @@ func (c *Cabin) Manager() *mgr.Manager { } func (c *Cabin) Start() error { - return prep() + return nil } func (c *Cabin) Stop() error { @@ -50,15 +50,16 @@ func New(instance instance) (*Cabin, error) { return nil, errors.New("only one instance allowed") } - if err := prep(); err != nil { - return nil, err - } - m := mgr.New("Cabin") module = &Cabin{ m: m, instance: instance, } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/spn/captain/config.go b/spn/captain/config.go index d6fa308d6..dbfc2563c 100644 --- a/spn/captain/config.go +++ b/spn/captain/config.go @@ -56,8 +56,24 @@ var ( ) func prepConfig() error { - // Home Node Rules + // Register spn module setting. err := config.Register(&config.Option{ + Name: "SPN Module", + Key: CfgOptionEnableSPNKey, + Description: "Start the Safing Privacy Network module. If turned off, the SPN is fully disabled on this device.", + OptType: config.OptTypeBool, + DefaultValue: false, + Annotations: config.Annotations{ + config.DisplayOrderAnnotation: cfgOptionEnableSPNOrder, + config.CategoryAnnotation: "General", + }, + }) + if err != nil { + return err + } + + // Home Node Rules + err = config.Register(&config.Option{ Name: "Home Node Rules", Key: CfgOptionHomeHubPolicyKey, Description: `Customize which countries should or should not be used for your Home Node. The Home Node is your entry into the SPN. You connect directly to it and all your connections are routed through it. diff --git a/spn/captain/module.go b/spn/captain/module.go index c526b5ea2..9676762c3 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -47,10 +47,6 @@ func (c *Captain) Manager() *mgr.Manager { } func (c *Captain) Start() error { - if err := prep(); err != nil { - return err - } - return start() } @@ -97,6 +93,7 @@ func prep() error { } // Register API endpoints. + // FIXME(vladimir): Does this need to be called during start or during construction of module? 
if err := registerAPIEndpoints(); err != nil { return err } @@ -259,6 +256,11 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Captain, error) { EventSPNConnected: mgr.NewEventMgr[struct{}](SPNConnectedEvent, m), maintainPublicStatus: m.NewWorkerMgr("maintain public status", maintainPublicStatus, nil), } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } From bb9c6458695b84ac5ec11b070789793d435922ef Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Tue, 9 Jul 2024 17:44:00 +0300 Subject: [PATCH 24/56] [WIP] Move API registration into module construction --- base/metrics/api.go | 1 + service/broadcasts/module.go | 8 +++++--- service/compat/module.go | 6 +++--- service/core/core.go | 9 +++++---- service/intel/customlists/module.go | 14 ++++++++------ service/intel/geoip/module.go | 27 ++++++++++++++------------- service/netenv/main.go | 10 ++++++---- service/netquery/module_api.go | 7 +++---- service/network/module.go | 8 +++++--- service/process/module.go | 18 +++++++----------- service/profile/module.go | 8 ++++---- service/resolver/main.go | 6 +++--- service/sync/module.go | 7 ++++++- service/ui/module.go | 9 +++++---- service/updates/main.go | 4 ++++ service/updates/module.go | 7 +------ spn/access/module.go | 9 +++++---- spn/captain/api.go | 5 ++--- spn/captain/module.go | 1 - spn/navigator/module.go | 8 ++++---- 20 files changed, 91 insertions(+), 81 deletions(-) diff --git a/base/metrics/api.go b/base/metrics/api.go index ccc3bfe90..c06daca73 100644 --- a/base/metrics/api.go +++ b/base/metrics/api.go @@ -17,6 +17,7 @@ import ( func registerAPI() error { api.RegisterHandler("/metrics", &metricsAPI{}) + // FIXME(vladimir): This needs to be moved to the prep function. if err := api.RegisterEndpoint(api.Endpoint{ Name: "Export Registered Metrics", Description: "List all registered metrics with their metadata.", diff --git a/service/broadcasts/module.go b/service/broadcasts/module.go index bb162afdd..31ab162f4 100644 --- a/service/broadcasts/module.go +++ b/service/broadcasts/module.go @@ -20,9 +20,6 @@ func (b *Broadcasts) Manager() *mgr.Manager { } func (b *Broadcasts) Start() error { - if err := prep(); err != nil { - return err - } return start() } @@ -79,6 +76,11 @@ func New(instance instance) (*Broadcasts, error) { mgr: m, instance: instance, } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/compat/module.go b/service/compat/module.go index 9cf6791bd..8b37bdd5c 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -26,9 +26,6 @@ func (u *Compat) Manager() *mgr.Manager { // Start starts the module.
func (u *Compat) Start() error { - if err := prep(); err != nil { - return err - } return start() } @@ -161,6 +158,9 @@ func New(instance instance) (*Compat, error) { mgr: m, instance: instance, } + if err := prep(); err != nil { + return nil, err + } return module, nil } diff --git a/service/core/core.go b/service/core/core.go index d6812d6b1..d689b5ff2 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -37,10 +37,6 @@ func (c *Core) Manager() *mgr.Manager { } func (c *Core) Start() error { - if err := prep(); err != nil { - return err - } - return start() } @@ -133,6 +129,11 @@ func New(instance instance) (*Core, error) { EventShutdown: mgr.NewEventMgr[struct{}]("shutdown", m), EventRestart: mgr.NewEventMgr[struct{}]("restart", m), } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/intel/customlists/module.go b/service/intel/customlists/module.go index 615a5f82d..ac0e0c0de 100644 --- a/service/intel/customlists/module.go +++ b/service/intel/customlists/module.go @@ -31,9 +31,6 @@ func (cl *CustomList) Manager() *mgr.Manager { } func (cl *CustomList) Start() error { - if err := prep(); err != nil { - return err - } return start() } @@ -61,6 +58,12 @@ var ( func prep() error { initFilterLists() + // Register the config in the ui. + err := registerConfig() + if err != nil { + return err + } + // Register api endpoint for updating the filter list. if err := api.RegisterEndpoint(api.Endpoint{ Path: "customlists/update", @@ -220,9 +223,8 @@ func New(instance instance) (*CustomList, error) { States: mgr.NewStateMgr(m), updateFilterListWorkerMgr: m.NewWorkerMgr("update custom filter list", checkAndUpdateFilterList, nil), } - // Register the config in the ui. - err := registerConfig() - if err != nil { + + if err := prep(); err != nil { return nil, err } diff --git a/service/intel/geoip/module.go b/service/intel/geoip/module.go index e8527e4ee..6c2bb55ef 100644 --- a/service/intel/geoip/module.go +++ b/service/intel/geoip/module.go @@ -19,19 +19,6 @@ func (g *GeoIP) Manager() *mgr.Manager { } func (g *GeoIP) Start() error { - if err := api.RegisterEndpoint(api.Endpoint{ - Path: "intel/geoip/countries", - Read: api.PermitUser, - // Do not attach to module, as the data is always available anyway. - StructFunc: func(ar *api.Request) (i interface{}, err error) { - return countries, nil - }, - Name: "Get Country Information", - Description: "Returns a map of country information centers indexed by ISO-A2 country code", - }); err != nil { - return err - } - module.instance.Updates().EventResourcesUpdated.AddCallback( "Check for GeoIP database updates", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { @@ -61,6 +48,20 @@ func New(instance instance) (*GeoIP, error) { mgr: m, instance: instance, } + + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "intel/geoip/countries", + Read: api.PermitUser, + // Do not attach to module, as the data is always available anyway. 
+ StructFunc: func(ar *api.Request) (i interface{}, err error) { + return countries, nil + }, + Name: "Get Country Information", + Description: "Returns a map of country information centers indexed by ISO-A2 country code", + }); err != nil { + return nil, err + } + return module, nil } diff --git a/service/netenv/main.go b/service/netenv/main.go index 44848a5fd..608b1a066 100644 --- a/service/netenv/main.go +++ b/service/netenv/main.go @@ -30,10 +30,6 @@ func (ne *NetEnv) Manager() *mgr.Manager { } func (ne *NetEnv) Start() error { - if err := prep(); err != nil { - return err - } - ne.m.Go( "monitor network changes", monitorNetworkChanges, @@ -52,6 +48,7 @@ func (ne *NetEnv) Stop() error { } func prep() error { + // FIXME(vladimir): Does this need to be in the prep function. checkForIPv6Stack() if err := registerAPIEndpoints(); err != nil { @@ -62,6 +59,7 @@ func prep() error { return err } + // FIXME(vladimir): Does this need to be in the prep function. return prepLocation() } @@ -102,6 +100,10 @@ func New(instance instance) (*NetEnv, error) { EventNetworkChange: mgr.NewEventMgr[struct{}]("network change", m), EventOnlineStatusChange: mgr.NewEventMgr[OnlineStatus]("online status change", m), } + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go index 518d260a7..7d7047613 100644 --- a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -204,10 +204,6 @@ func (nq *NetQuery) Manager() *mgr.Manager { } func (nq *NetQuery) Start() error { - if err := nq.prepare(); err != nil { - return fmt.Errorf("failed to prepare netquery module: %w", err) - } - nq.mgr.Go("netquery connection feed listener", func(ctx *mgr.WorkerCtx) error { sub, err := nq.db.Subscribe(query.New("network:")) if err != nil { @@ -326,6 +322,9 @@ func NewModule(instance instance) (*NetQuery, error) { mgr: m, instance: instance, } + if err := module.prepare(); err != nil { + return nil, fmt.Errorf("failed to prepare netquery module: %w", err) + } return module, nil } diff --git a/service/network/module.go b/service/network/module.go index 396e81533..4cab1cb15 100644 --- a/service/network/module.go +++ b/service/network/module.go @@ -35,9 +35,6 @@ func (n *Network) Manager() *mgr.Manager { } func (n *Network) Start() error { - if err := prep(); err != nil { - return err - } return start() } @@ -181,6 +178,11 @@ func New(instance instance) (*Network, error) { instance: instance, EventConnectionReattributed: mgr.NewEventMgr[string](ConnectionReattributedEvent, m), } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/process/module.go b/service/process/module.go index ccab72d55..563368ab4 100644 --- a/service/process/module.go +++ b/service/process/module.go @@ -19,7 +19,11 @@ func (pm *ProcessModule) Manager() *mgr.Manager { } func (pm *ProcessModule) Start() error { - return start() + updatesPath = updates.RootPath() + if updatesPath != "" { + updatesPath += string(os.PathSeparator) + } + return nil } func (pm *ProcessModule) Stop() error { @@ -29,13 +33,8 @@ func (pm *ProcessModule) Stop() error { var updatesPath string func prep() error { - return registerConfiguration() -} - -func start() error { - updatesPath = updates.RootPath() - if updatesPath != "" { - updatesPath += string(os.PathSeparator) + if err := registerConfiguration(); err != nil { + return err } if err := registerAPIEndpoints(); err != nil { @@ -56,9 +55,6 @@ func New(instance instance) 
(*ProcessModule, error) { return nil, errors.New("only one instance allowed") } - if err := prep(); err != nil { - return nil, err - } m := mgr.New("ProcessModule") module = &ProcessModule{ mgr: m, diff --git a/service/profile/module.go b/service/profile/module.go index 8f08008aa..5798b00d2 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -65,6 +65,10 @@ func prep() error { return err } + if err := registerAPIEndpoints(); err != nil { + return err + } + // Setup icon storage location. iconsDir := dataroot.Root().ChildDir("databases", 0o0700).ChildDir("icons", 0o0700) if err := iconsDir.Ensure(); err != nil { @@ -115,10 +119,6 @@ func start() error { log.Warningf("profile: error during loading global profile from configuration: %s", err) } - if err := registerAPIEndpoints(); err != nil { - return err - } - return nil } diff --git a/service/resolver/main.go b/service/resolver/main.go index 69e4c055e..9ffe8e44b 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -36,9 +36,6 @@ func (rm *ResolverModule) Manager() *mgr.Manager { } func (rm *ResolverModule) Start() error { - if err := prep(); err != nil { - return err - } return start() } @@ -265,6 +262,9 @@ func New(instance instance) (*ResolverModule, error) { States: mgr.NewStateMgr(m), } + if err := prep(); err != nil { + return nil, err + } return module, nil } diff --git a/service/sync/module.go b/service/sync/module.go index c6a8c9d1f..06f50d2e5 100644 --- a/service/sync/module.go +++ b/service/sync/module.go @@ -18,7 +18,7 @@ func (s *Sync) Manager() *mgr.Manager { } func (s *Sync) Start() error { - return prep() + return nil } func (s *Sync) Stop() error { @@ -58,6 +58,11 @@ func New(instance instance) (*Sync, error) { mgr: m, instance: instance, } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/ui/module.go b/service/ui/module.go index 09639ba22..630808e50 100644 --- a/service/ui/module.go +++ b/service/ui/module.go @@ -48,10 +48,6 @@ func (ui *UI) Manager() *mgr.Manager { // Start starts the module. func (ui *UI) Start() error { - if err := prep(); err != nil { - return err - } - return start() } @@ -72,6 +68,11 @@ func New(instance instance) (*UI, error) { mgr: m, instance: instance, } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/service/updates/main.go b/service/updates/main.go index a2d0167b8..279529c97 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -96,6 +96,10 @@ func prep() error { } } + if err := registerConfig(); err != nil { + return err + } + return registerAPIEndpoints() } diff --git a/service/updates/module.go b/service/updates/module.go index df8e4fdd1..0487d2c14 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -47,8 +47,7 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { instance: instance, shutdownFunc: shutdownFunc, } - - if err := registerConfig(); err != nil { + if err := prep(); err != nil { return nil, err } @@ -67,10 +66,6 @@ func (u *Updates) Manager() *mgr.Manager { // Start starts the module. 
func (u *Updates) Start() error { - if err := prep(); err != nil { - return err - } - return start() } diff --git a/spn/access/module.go b/spn/access/module.go index b62599c20..92857bbc0 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -30,10 +30,6 @@ func (a *Access) Manager() *mgr.Manager { } func (a *Access) Start() error { - if err := prep(); err != nil { - return err - } - return start() } @@ -233,6 +229,11 @@ func New(instance instance) (*Access, error) { EventAccountUpdate: mgr.NewEventMgr[struct{}](AccountUpdateEvent, m), updateAccountWorkerMgr: m.NewWorkerMgr("update account", UpdateAccount, nil), } + + if err := prep(); err != nil { + return nil, err + } + return module, nil } diff --git a/spn/captain/api.go b/spn/captain/api.go index a3108facc..fc19136f5 100644 --- a/spn/captain/api.go +++ b/spn/captain/api.go @@ -10,9 +10,8 @@ const ( func registerAPIEndpoints() error { if err := api.RegisterEndpoint(api.Endpoint{ - Path: apiPathForSPNReInit, - Write: api.PermitAdmin, - // BelongsTo: module, // Do not attach to module, as this must run outside of the module. + Path: apiPathForSPNReInit, + Write: api.PermitAdmin, ActionFunc: handleReInit, Name: "Re-initialize SPN", Description: "Stops the SPN, resets all caches and starts it again. The SPN account and settings are not changed.", diff --git a/spn/captain/module.go b/spn/captain/module.go index 9676762c3..d66b2763b 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -93,7 +93,6 @@ func prep() error { } // Register API endpoints. - // FIXME(vladimir): Does this need to be called during start or during construction of module? if err := registerAPIEndpoints(); err != nil { return err } diff --git a/spn/navigator/module.go b/spn/navigator/module.go index 4392b98dc..0d8f235c1 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -45,10 +45,6 @@ func (n *Navigator) Manager() *mgr.Manager { } func (n *Navigator) Start() error { - if err := prep(); err != nil { - return err - } - return start() } @@ -153,6 +149,10 @@ func New(instance instance) (*Navigator, error) { mgr: m, instance: instance, } + if err := prep(); err != nil { + return nil, err + } + return module, nil } From 0bcdd5ce6723f5b8b2234c1bd24a2f101fa2951f Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Thu, 11 Jul 2024 14:05:04 +0300 Subject: [PATCH 25/56] [WIP] Update states mgr for all modules --- base/config/main.go | 1 - base/metrics/module.go | 9 ++++----- base/notifications/module.go | 12 +++++++----- service/compat/module.go | 7 +++++++ service/intel/customlists/lists.go | 8 ++++---- service/intel/customlists/module.go | 8 ++++++-- service/intel/filterlists/module.go | 10 +++++++--- service/intel/filterlists/updater.go | 8 ++++---- service/mgr/states.go | 4 ++++ service/nameserver/module.go | 10 +++++++--- service/profile/config-update.go | 4 ++-- service/profile/migrations.go | 4 ++-- service/profile/module.go | 8 ++++++-- service/resolver/main.go | 10 +++++++--- service/updates/module.go | 4 ++-- spn/captain/client.go | 6 +++--- spn/captain/module.go | 8 ++++++-- 17 files changed, 78 insertions(+), 43 deletions(-) diff --git a/base/config/main.go b/base/config/main.go index a1b3b19fd..92e546252 100644 --- a/base/config/main.go +++ b/base/config/main.go @@ -19,7 +19,6 @@ import ( const ChangeEvent = "config change" var ( - // module *modules.Module dataRoot *utils.DirStructure exportConfig bool diff --git a/base/metrics/module.go b/base/metrics/module.go index d77a982f3..ccabb4236 100644 --- a/base/metrics/module.go 
+++ b/base/metrics/module.go @@ -85,12 +85,8 @@ func start() error { return err } - if err := registerAPI(); err != nil { - return err - } - if pushOption() != "" { - module.mgr.Do("metric pusher", metricsWriter) + _ = module.mgr.Do("metric pusher", metricsWriter) } return nil @@ -203,6 +199,9 @@ func New(instance instance) (*Metrics, error) { return nil, err } + if err := registerAPI(); err != nil { + return nil, err + } return module, nil } diff --git a/base/notifications/module.go b/base/notifications/module.go index a3b7d835c..f69d017ee 100644 --- a/base/notifications/module.go +++ b/base/notifications/module.go @@ -13,16 +13,18 @@ type Notifications struct { mgr *mgr.Manager instance instance - States *mgr.StateMgr + states *mgr.StateMgr } func (n *Notifications) Manager() *mgr.Manager { return n.mgr } -func (n *Notifications) Start() error { - n.States = mgr.NewStateMgr(n.mgr) +func (n *Notifications) States() *mgr.StateMgr { + return n.states +} +func (n *Notifications) Start() error { if err := prep(); err != nil { return err } @@ -57,7 +59,7 @@ func showConfigLoadingErrors() { } // Trigger a module error for more awareness. - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: "config:validation-errors-on-load", Name: "Invalid Settings", Message: "Some current settings are invalid. Please update them and restart the Portmaster.", @@ -100,7 +102,7 @@ func New(instance instance) (*Notifications, error) { mgr: m, instance: instance, - States: mgr.NewStateMgr(m), + states: mgr.NewStateMgr(m), } return module, nil diff --git a/service/compat/module.go b/service/compat/module.go index 8b37bdd5c..5726b355d 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -18,12 +18,17 @@ type Compat struct { instance instance selfcheckWorkerMgr *mgr.WorkerMgr + states *mgr.StateMgr } func (u *Compat) Manager() *mgr.Manager { return u.mgr } +func (u *Compat) States() *mgr.StateMgr { + return u.states +} + // Start starts the module. func (u *Compat) Start() error { return start() @@ -157,6 +162,8 @@ func New(instance instance) (*Compat, error) { module = &Compat{ mgr: m, instance: instance, + + states: mgr.NewStateMgr(m), } if err := prep(); err != nil { return nil, err diff --git a/service/intel/customlists/lists.go b/service/intel/customlists/lists.go index 66935991e..797d0723c 100644 --- a/service/intel/customlists/lists.go +++ b/service/intel/customlists/lists.go @@ -80,7 +80,7 @@ func parseFile(filePath string) error { file, err := os.Open(filePath) if err != nil { log.Warningf("intel/customlists: failed to parse file %s", err) - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: parseWarningNotificationID, Name: "Failed to open custom filter list", Message: err.Error(), @@ -113,7 +113,7 @@ func parseFile(filePath string) error { if invalidLinesRation > rationForInvalidLinesUntilWarning { log.Warning("intel/customlists: Too many invalid lines") - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: zeroIPNotificationID, Name: "Custom filter list has many invalid lines", Message: fmt.Sprintf(`%d out of %d lines are invalid. 
@@ -121,7 +121,7 @@ func parseFile(filePath string) error { Type: mgr.StateTypeWarning, }) } else { - module.States.Remove(zeroIPNotificationID) + module.states.Remove(zeroIPNotificationID) } allEntriesCount := len(domainsFilterList) + len(ipAddressesFilterList) + len(autonomousSystemsFilterList) + len(countryCodesFilterList) @@ -140,7 +140,7 @@ func parseFile(filePath string) error { len(autonomousSystemsFilterList), len(countryCodesFilterList))) - module.States.Remove(parseWarningNotificationID) + module.states.Remove(parseWarningNotificationID) return nil } diff --git a/service/intel/customlists/module.go b/service/intel/customlists/module.go index ac0e0c0de..08b21d7f7 100644 --- a/service/intel/customlists/module.go +++ b/service/intel/customlists/module.go @@ -23,13 +23,17 @@ type CustomList struct { updateFilterListWorkerMgr *mgr.WorkerMgr - States *mgr.StateMgr + states *mgr.StateMgr } func (cl *CustomList) Manager() *mgr.Manager { return cl.mgr } +func (cl *CustomList) States() *mgr.StateMgr { + return cl.states +} + func (cl *CustomList) Start() error { return start() } @@ -220,7 +224,7 @@ func New(instance instance) (*CustomList, error) { mgr: m, instance: instance, - States: mgr.NewStateMgr(m), + states: mgr.NewStateMgr(m), updateFilterListWorkerMgr: m.NewWorkerMgr("update custom filter list", checkAndUpdateFilterList, nil), } diff --git a/service/intel/filterlists/module.go b/service/intel/filterlists/module.go index c499ef209..c529c207b 100644 --- a/service/intel/filterlists/module.go +++ b/service/intel/filterlists/module.go @@ -23,13 +23,17 @@ type FilterLists struct { mgr *mgr.Manager instance instance - States *mgr.StateMgr + states *mgr.StateMgr } func (fl *FilterLists) Manager() *mgr.Manager { return fl.mgr } +func (fl *FilterLists) States() *mgr.StateMgr { + return fl.states +} + func (fl *FilterLists) Start() error { if err := prep(); err != nil { return err @@ -110,7 +114,7 @@ func stop() error { } func warnAboutDisabledFilterLists() { - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: filterlistsDisabled, Name: "Filter Lists Are Initializing", Message: "Filter lists are being downloaded and set up in the background. They will be activated as configured when finished.", @@ -133,7 +137,7 @@ func New(instance instance) (*FilterLists, error) { mgr: m, instance: instance, - States: mgr.NewStateMgr(m), + states: mgr.NewStateMgr(m), } return module, nil } diff --git a/service/intel/filterlists/updater.go b/service/intel/filterlists/updater.go index 8d3b19237..c0be14dac 100644 --- a/service/intel/filterlists/updater.go +++ b/service/intel/filterlists/updater.go @@ -34,13 +34,13 @@ func tryListUpdate(ctx context.Context) error { // generic one with the returned error. hasWarningState := false - for _, state := range module.States.Export().States { + for _, state := range module.states.Export().States { if state.Type == mgr.StateTypeWarning { hasWarningState = true } } if !hasWarningState { - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: filterlistsUpdateFailed, Name: "Filter Lists Update Failed", Message: fmt.Sprintf("The Portmaster failed to process a filter lists update. Filtering capabilities are currently either impaired or not available at all. Error: %s", err.Error()), @@ -134,7 +134,7 @@ func performUpdate(ctx context.Context) error { // if we failed to remove all stale cache entries // we abort now WITHOUT updating the database version. This means // we'll try again during the next update. 
- module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: filterlistsStaleDataSurvived, Name: "Filter Lists May Overblock", Message: fmt.Sprintf("The Portmaster failed to delete outdated filter list data. Filtering capabilities are fully available, but overblocking may occur. Error: %s", err.Error()), //nolint:misspell // overblocking != overclocking @@ -153,7 +153,7 @@ func performUpdate(ctx context.Context) error { } // The list update succeeded, resolve any states. - module.States.Clear() + module.states.Clear() return nil } diff --git a/service/mgr/states.go b/service/mgr/states.go index e6f049219..feba5280e 100644 --- a/service/mgr/states.go +++ b/service/mgr/states.go @@ -43,6 +43,10 @@ type StateUpdate struct { States []State } +type StatefulModule interface { + States() *StateMgr +} + // NewStateMgr returns a new state manager. func NewStateMgr(mgr *Manager) *StateMgr { return &StateMgr{ diff --git a/service/nameserver/module.go b/service/nameserver/module.go index 248b90abf..f2f3e4305 100644 --- a/service/nameserver/module.go +++ b/service/nameserver/module.go @@ -23,13 +23,17 @@ type NameServer struct { mgr *mgr.Manager instance instance - States *mgr.StateMgr + states *mgr.StateMgr } func (ns *NameServer) Manager() *mgr.Manager { return ns.mgr } +func (ns *NameServer) States() *mgr.StateMgr { + return ns.states +} + func (ns *NameServer) Start() error { return start() } @@ -156,7 +160,7 @@ func startListener(ip net.IP, port uint16, first bool) { // Resolve generic listener error, if primary listener. if first { - module.States.Remove(eventIDListenerFailed) + module.states.Remove(eventIDListenerFailed) } // Start listening. @@ -320,7 +324,7 @@ func New(instance instance) (*NameServer, error) { mgr: m, instance: instance, - States: mgr.NewStateMgr(m), + states: mgr.NewStateMgr(m), } if err := prep(); err != nil { return nil, err diff --git a/service/profile/config-update.go b/service/profile/config-update.go index 94c64cf0d..acbe7b85b 100644 --- a/service/profile/config-update.go +++ b/service/profile/config-update.go @@ -129,7 +129,7 @@ func updateGlobalConfigProfile(_ context.Context) error { // If there was any error, try again later until it succeeds. if lastErr == nil { - module.States.Remove(globalConfigProfileErrorID) + module.states.Remove(globalConfigProfileErrorID) } else { // Create task after first failure. @@ -140,7 +140,7 @@ func updateGlobalConfigProfile(_ context.Context) error { }) // Add module warning to inform user. - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: globalConfigProfileErrorID, Name: "Internal Settings Failure", Message: fmt.Sprintf("Some global settings might not be applied correctly. You can try restarting the Portmaster to resolve this problem. Error: %s", err), diff --git a/service/profile/migrations.go b/service/profile/migrations.go index 72829b0ce..d081a3fb3 100644 --- a/service/profile/migrations.go +++ b/service/profile/migrations.go @@ -130,7 +130,7 @@ func migrateIcons(ctx context.Context, _, to *version.Version, db *database.Inte if lastErr != nil { // Normally, an icon migration would not be such a big error, but this is a test // run for the profile IDs and we absolutely need to know if anything went wrong. - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: "migration-failed", Name: "Profile Migration Failed", Message: fmt.Sprintf("Failed to migrate icons of %d profiles (out of %d pending). 
The last error was: %s\n\nPlease restart Portmaster to try the migration again.", failed, total, lastErr), @@ -219,7 +219,7 @@ func migrateToDerivedIDs(ctx context.Context, _, to *version.Version, db *databa // Log migration failure and try again next time. if lastErr != nil { - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: "migration-failed", Name: "Profile Migration Failed", Message: fmt.Sprintf("Failed to migrate profile IDs of %d profiles (out of %d pending). The last error was: %s\n\nPlease restart Portmaster to try the migration again.", failed, total, lastErr), diff --git a/service/profile/module.go b/service/profile/module.go index 5798b00d2..01659c761 100644 --- a/service/profile/module.go +++ b/service/profile/module.go @@ -37,13 +37,17 @@ type ProfileModule struct { EventDelete *mgr.EventMgr[string] EventMigrated *mgr.EventMgr[[]string] - States *mgr.StateMgr + states *mgr.StateMgr } func (pm *ProfileModule) Manager() *mgr.Manager { return pm.mgr } +func (pm *ProfileModule) States() *mgr.StateMgr { + return pm.states +} + func (pm *ProfileModule) Start() error { return start() } @@ -144,7 +148,7 @@ func NewModule(instance instance) (*ProfileModule, error) { EventDelete: mgr.NewEventMgr[string](DeletedEvent, m), EventMigrated: mgr.NewEventMgr[[]string](MigratedEvent, m), - States: mgr.NewStateMgr(m), + states: mgr.NewStateMgr(m), } if err := prep(); err != nil { diff --git a/service/resolver/main.go b/service/resolver/main.go index 9ffe8e44b..46c26ab02 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -28,13 +28,17 @@ type ResolverModule struct { failingResolverWorkerMgr *mgr.WorkerMgr suggestUsingStaleCacheTask *mgr.WorkerMgr - States *mgr.StateMgr + states *mgr.StateMgr } func (rm *ResolverModule) Manager() *mgr.Manager { return rm.mgr } +func (rm *ResolverModule) States() *mgr.StateMgr { + return rm.states +} + func (rm *ResolverModule) Start() error { return start() } @@ -203,7 +207,7 @@ func resetFailingResolversNotification() { } // Additionally, resolve the module error, if not done through the notification. - module.States.Remove(failingResolverErrorID) + module.states.Remove(failingResolverErrorID) } // AddToDebugInfo adds the system status to the given debug.Info. @@ -260,7 +264,7 @@ func New(instance instance) (*ResolverModule, error) { mgr: m, instance: instance, - States: mgr.NewStateMgr(m), + states: mgr.NewStateMgr(m), } if err := prep(); err != nil { return nil, err diff --git a/service/updates/module.go b/service/updates/module.go index 0487d2c14..43c2a48cc 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -54,8 +54,8 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { return module, nil } -// State returns the state manager. -func (u *Updates) State() *mgr.StateMgr { +// States returns the state manager. 
+func (u *Updates) States() *mgr.StateMgr { return u.states } diff --git a/spn/captain/client.go b/spn/captain/client.go index 4edc6264d..49bcf7828 100644 --- a/spn/captain/client.go +++ b/spn/captain/client.go @@ -77,11 +77,11 @@ func clientManager(ctx *mgr.WorkerCtx) error { ready.UnSet() netenv.ConnectedToSPN.UnSet() resetSPNStatus(StatusDisabled, true) - module.States.Clear() + module.states.Clear() clientStopHomeHub(ctx.Ctx()) }() - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: "spn:establishing-home-hub", Name: "Connecting to SPN...", Message: "Connecting to the SPN network is in progress.", @@ -423,7 +423,7 @@ func clientSetActiveConnectionStatus(ctx context.Context) clientComponentResult } // Resolve any connection error. - module.States.Clear() + module.states.Clear() // Update SPN Status with connection information, if not already correctly set. spnStatus.Lock() diff --git a/spn/captain/module.go b/spn/captain/module.go index d66b2763b..a8ba0e567 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -38,7 +38,7 @@ type Captain struct { healthCheckTicker *mgr.SleepyTicker maintainPublicStatus *mgr.WorkerMgr - States *mgr.StateMgr + states *mgr.StateMgr EventSPNConnected *mgr.EventMgr[struct{}] } @@ -46,6 +46,10 @@ func (c *Captain) Manager() *mgr.Manager { return c.mgr } +func (c *Captain) States() *mgr.StateMgr { + return c.states +} + func (c *Captain) Start() error { return start() } @@ -251,7 +255,7 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Captain, error) { instance: instance, shutdownFunc: shutdownFunc, - States: mgr.NewStateMgr(m), + states: mgr.NewStateMgr(m), EventSPNConnected: mgr.NewEventMgr[struct{}](SPNConnectedEvent, m), maintainPublicStatus: m.NewWorkerMgr("maintain public status", maintainPublicStatus, nil), } From b60c7362d92c916ecadddd315f78c0fd5db53876 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Thu, 11 Jul 2024 14:05:53 +0300 Subject: [PATCH 26/56] [WIP] Add CmdLine operation support --- base/api/main.go | 3 +-- base/api/module.go | 1 + base/config/main.go | 6 +----- base/config/module.go | 4 +++- base/metrics/api.go | 1 - cmds/portmaster-core/main.go | 13 ++++++++++++- service/instance.go | 15 ++++++++++++--- service/resolver/resolvers.go | 6 +++--- 8 files changed, 33 insertions(+), 16 deletions(-) diff --git a/base/api/main.go b/base/api/main.go index 62b32ac7b..83c97c0f4 100644 --- a/base/api/main.go +++ b/base/api/main.go @@ -26,8 +26,7 @@ func init() { func prep() error { if exportEndpoints { - // FIXME(vladimir): migrate - // modules.SetCmdLineOperation(exportEndpointsCmd) + module.instance.SetCmdLineOperation(exportEndpointsCmd) } if getDefaultListenAddress() == "" { diff --git a/base/api/module.go b/base/api/module.go index 1f912ba42..ba4a038f0 100644 --- a/base/api/module.go +++ b/base/api/module.go @@ -59,4 +59,5 @@ func New(instance instance) (*API, error) { type instance interface { Config() *config.Config + SetCmdLineOperation(f func() error) } diff --git a/base/config/main.go b/base/config/main.go index 92e546252..528e22871 100644 --- a/base/config/main.go +++ b/base/config/main.go @@ -32,9 +32,6 @@ func SetDataRoot(root *utils.DirStructure) { } func init() { - // module = modules.Register("config", prep, start, nil, "database") - // module.RegisterEvent(ChangeEvent, true) - flag.BoolVar(&exportConfig, "export-config-options", false, "export configuration registry and exit") } @@ -45,8 +42,7 @@ func prep() error { } if exportConfig { - // FIXME(vladimir): migrate - // 
modules.SetCmdLineOperation(exportConfigCmd)
+		module.instance.SetCmdLineOperation(exportConfigCmd)
 	}
 
 	return registerBasicOptions()
diff --git a/base/config/module.go b/base/config/module.go
index e0d68fdba..afd7cc397 100644
--- a/base/config/module.go
+++ b/base/config/module.go
@@ -54,4 +54,6 @@ func New(instance instance) (*Config, error) {
 	return module, nil
 }
 
-type instance interface{}
+type instance interface {
+	SetCmdLineOperation(f func() error)
+}
diff --git a/base/metrics/api.go b/base/metrics/api.go
index c06daca73..ccc3bfe90 100644
--- a/base/metrics/api.go
+++ b/base/metrics/api.go
@@ -17,7 +17,6 @@ import (
 func registerAPI() error {
 	api.RegisterHandler("/metrics", &metricsAPI{})
 
-	// FIXME(vladimir): This needs to be moved to the prep function.
 	if err := api.RegisterEndpoint(api.Endpoint{
 		Name:        "Export Registered Metrics",
 		Description: "List all registered metrics with their metadata.",
diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go
index b61173cbe..ab3d8a0f6 100644
--- a/cmds/portmaster-core/main.go
+++ b/cmds/portmaster-core/main.go
@@ -39,7 +39,7 @@ func main() {
 	// Set default log level.
 	log.SetLogLevel(log.WarningLevel)
-	log.Start()
+	_ = log.Start()
 
 	// Configure metrics.
 	_ = metrics.SetNamespace("portmaster")
@@ -67,6 +67,17 @@ func main() {
 		fmt.Printf("error creating an instance: %s\n", err)
 		return
 	}
+
+	// execute command if available
+	if instance.CommandLineOperation != nil {
+		// Run the function and exit.
+		if err := instance.CommandLineOperation(); err != nil {
+			fmt.Fprintf(os.Stderr, "cmdline operation failed: %s\n", err)
+			os.Exit(1)
+		}
+		os.Exit(0)
+	}
+
 	// Start
 	go func() {
 		err = instance.Group.Start()
diff --git a/service/instance.go b/service/instance.go
index 6291dd7c7..9a2052d6e 100644
--- a/service/instance.go
+++ b/service/instance.go
@@ -45,10 +45,9 @@ import (
 
 // Instance is an instance of a portmaste service.
 type Instance struct {
-	*mgr.Group
-
 	version string
 
+	*mgr.Group
 	database *dbmodule.DBModule
 	config   *config.Config
 	api      *api.API
@@ -91,6 +90,8 @@ type Instance struct {
 	ships    *ships.Ships
 	sluice   *sluice.SluiceModule
 	terminal *terminal.TerminalModule
+
+	CommandLineOperation func() error
 }
 
 // New returns a new portmaster service instance.
@@ -511,7 +512,15 @@ func (i *Instance) SPNGroup() *mgr.Group {
 }
 
 // Events
-
-// SPN connected
+
+// GetEventSPNConnected returns the event manager for the SPN connected event.
 func (i *Instance) GetEventSPNConnected() *mgr.EventMgr[struct{}] {
 	return i.captain.EventSPNConnected
 }
+
+// Special functions
+
+// SetCmdLineOperation sets a command line operation to be executed instead of starting the system. This is useful when functions need all modules to be prepared for a special operation.
+func (i *Instance) SetCmdLineOperation(f func() error) {
+	i.CommandLineOperation = f
+}
diff --git a/service/resolver/resolvers.go b/service/resolver/resolvers.go
index 4c2335201..22170082c 100644
--- a/service/resolver/resolvers.go
+++ b/service/resolver/resolvers.go
@@ -370,7 +370,7 @@ func loadResolvers() {
 	defer resolversLock.Unlock()
 
 	// Resolve module error about missing resolvers.
-	module.States.Remove(missingResolversErrorID)
+	module.states.Remove(missingResolversErrorID)
 
 	// Check if settings were changed and clear name cache when they did.
newResolverConfig := configuredNameServers() @@ -394,7 +394,7 @@ func loadResolvers() { newResolvers = getConfiguredResolvers(defaultNameServers) if len(newResolvers) > 0 { log.Warning("resolver: no (valid) dns server found in config or system, falling back to global defaults") - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: missingResolversErrorID, Name: "Using Factory Default DNS Servers", Message: "The Portmaster could not find any (valid) DNS servers in the settings or system. In order to prevent being disconnected, the factory defaults are being used instead. If you just switched your network, this should be resolved shortly.", @@ -402,7 +402,7 @@ func loadResolvers() { }) } else { log.Critical("resolver: no (valid) dns server found in config, system or global defaults") - module.States.Add(mgr.State{ + module.states.Add(mgr.State{ ID: missingResolversErrorID, Name: "No DNS Servers Configured", Message: "The Portmaster could not find any (valid) DNS servers in the settings or system. You will experience severe connectivity problems until resolved. If you just switched your network, this should be resolved shortly.", From 482f2ffd29b3504f9a05fb4aff54a5f32eb5f9d9 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jul 2024 16:38:51 +0200 Subject: [PATCH 27/56] Add state helper methods to module group and instance --- service/instance.go | 19 +++++++++++++++++++ service/mgr/module.go | 21 +++++++++++++++++++++ service/mgr/states.go | 20 ++++++++++++++++++-- 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/service/instance.go b/service/instance.go index 9a2052d6e..da266b041 100644 --- a/service/instance.go +++ b/service/instance.go @@ -524,3 +524,22 @@ func (i *Instance) GetEventSPNConnected() *mgr.EventMgr[struct{}] { func (i *Instance) SetCmdLineOperation(f func() error) { i.CommandLineOperation = f } + +// GetStatus returns the current Status of all group modules. +func (i *Instance) GetStatus() []mgr.StateUpdate { + mainStates := i.Group.GetStatus() + spnStates := i.SpnGroup.GetStatus() + + updates := make([]mgr.StateUpdate, 0, len(mainStates)+len(spnStates)) + updates = append(updates, mainStates...) + updates = append(updates, spnStates...) + + return updates +} + +// AddStatusCallback adds the given callback function to all group modules that +// expose a state manager at States(). +func (i *Instance) AddStatusCallback(callbackName string, callback mgr.EventCallbackFunc[mgr.StateUpdate]) { + i.Group.AddStatusCallback(callbackName, callback) + i.SpnGroup.AddStatusCallback(callbackName, callback) +} diff --git a/service/mgr/module.go b/service/mgr/module.go index 93de1efa1..d56135dd5 100644 --- a/service/mgr/module.go +++ b/service/mgr/module.go @@ -196,6 +196,27 @@ func (g *Group) IsDone() bool { return g.ctx.Err() != nil } +// GetStatus returns the current Status of all group modules. +func (g *Group) GetStatus() []StateUpdate { + updates := make([]StateUpdate, 0, len(g.modules)) + for _, gm := range g.modules { + if stateful, ok := gm.module.(StatefulModule); ok { + updates = append(updates, stateful.States().Export()) + } + } + return updates +} + +// AddStatusCallback adds the given callback function to all group modules that +// expose a state manager at States(). 
+func (g *Group) AddStatusCallback(callbackName string, callback EventCallbackFunc[StateUpdate]) { + for _, gm := range g.modules { + if stateful, ok := gm.module.(StatefulModule); ok { + stateful.States().AddCallback(callbackName, callback) + } + } +} + // RunModules is a simple wrapper function to start modules and stop them again // when the given context is canceled. func RunModules(ctx context.Context, modules ...Module) error { diff --git a/service/mgr/states.go b/service/mgr/states.go index feba5280e..5ecb456a5 100644 --- a/service/mgr/states.go +++ b/service/mgr/states.go @@ -37,9 +37,25 @@ const ( StateTypeError = "error" ) +// Severity returns a number representing the gravity of the state for ordering. +func (st StateType) Severity() int { + switch st { + case StateTypeUndefined: + return 0 + case StateTypeHint: + return 1 + case StateTypeWarning: + return 2 + case StateTypeError: + return 3 + default: + return 0 + } +} + // StateUpdate is used to update others about a state change. type StateUpdate struct { - Name string + Module string States []State } @@ -121,7 +137,7 @@ func (m *StateMgr) export() StateUpdate { } return StateUpdate{ - Name: name, + Module: name, States: slices.Clone(m.states), } } From d142536789859dbed9a582abe0daefe17f3116f1 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jul 2024 16:39:23 +0200 Subject: [PATCH 28/56] Add notification and module status handling to status package --- service/status/module.go | 44 ++++++++++- service/status/notifications.go | 57 +++++++++++++++ service/status/provider.go | 49 ------------- service/status/records.go | 23 ------ service/status/status.go | 125 ++++++++++++++++++++++++++++++++ 5 files changed, 222 insertions(+), 76 deletions(-) create mode 100644 service/status/notifications.go delete mode 100644 service/status/provider.go delete mode 100644 service/status/records.go create mode 100644 service/status/status.go diff --git a/service/status/module.go b/service/status/module.go index 946bd41cd..6d5edfb3e 100644 --- a/service/status/module.go +++ b/service/status/module.go @@ -3,37 +3,67 @@ package status import ( "errors" "fmt" + "sync" "sync/atomic" + "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/base/utils/debug" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" ) +// Status Module manages status information. type Status struct { mgr *mgr.Manager instance instance + + publishUpdate runtime.PushFunc + triggerUpdate chan struct{} + + states map[string]mgr.StateUpdate + statesLock sync.Mutex + + notifications map[string]map[string]*notifications.Notification + notificationsLock sync.Mutex } +// Manager returns the module manager. func (s *Status) Manager() *mgr.Manager { return s.mgr } +// Start starts the module. func (s *Status) Start() error { - if err := setupRuntimeProvider(); err != nil { + if err := s.setupRuntimeProvider(); err != nil { return err } + s.mgr.Go("status publisher", s.statusPublisher) + s.instance.NetEnv().EventOnlineStatusChange.AddCallback("update online status in system status", func(_ *mgr.WorkerCtx, _ netenv.OnlineStatus) (bool, error) { - pushSystemStatus() + s.triggerPublishStatus() return false, nil }, ) + // Make an initial status query. + s.statesLock.Lock() + defer s.statesLock.Unlock() + // Add status callback within the lock so we can force the right order. + s.instance.AddStatusCallback("status update", s.handleModuleStatusUpdate) + // Get initial states. 
+ for _, stateUpdate := range s.instance.GetStatus() { + s.mgr.Info("status update", stateUpdate) + s.states[stateUpdate.Module] = stateUpdate + s.deriveNotificationsFromStateUpdate(stateUpdate) + } + return nil } +// Stop stops the module. func (s *Status) Stop() error { return nil } @@ -53,14 +83,18 @@ var ( shimLoaded atomic.Bool ) +// New returns a new status module. func New(instance instance) (*Status, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } m := mgr.New("Status") module = &Status{ - mgr: m, - instance: instance, + mgr: m, + instance: instance, + triggerUpdate: make(chan struct{}, 1), + states: make(map[string]mgr.StateUpdate), + notifications: make(map[string]map[string]*notifications.Notification), } return module, nil @@ -68,4 +102,6 @@ func New(instance instance) (*Status, error) { type instance interface { NetEnv() *netenv.NetEnv + GetStatus() []mgr.StateUpdate + AddStatusCallback(callbackName string, callback mgr.EventCallbackFunc[mgr.StateUpdate]) } diff --git a/service/status/notifications.go b/service/status/notifications.go new file mode 100644 index 000000000..c5208241a --- /dev/null +++ b/service/status/notifications.go @@ -0,0 +1,57 @@ +package status + +import ( + "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/service/mgr" +) + +func (s *Status) deriveNotificationsFromStateUpdate(update mgr.StateUpdate) { + s.notificationsLock.Lock() + defer s.notificationsLock.Unlock() + + notifs := s.notifications[update.Module] + if notifs == nil { + notifs = make(map[string]*notifications.Notification) + s.notifications[update.Module] = notifs + } + + // Add notifications. + seenStateIDs := make(map[string]struct{}, len(update.States)) + for _, state := range update.States { + seenStateIDs[state.ID] = struct{}{} + + _, ok := notifs[state.ID] + if !ok { + n := notifications.Notify(¬ifications.Notification{ + EventID: update.Module + ":" + state.ID, + Type: stateTypeToNotifType(state.Type), + Title: state.Name, + Message: state.Message, + }) + notifs[state.ID] = n + } + } + + // Remove notifications. + for stateID, n := range notifs { + if _, ok := seenStateIDs[stateID]; !ok { + n.Delete() + delete(notifs, stateID) + } + } +} + +func stateTypeToNotifType(stateType mgr.StateType) notifications.Type { + switch stateType { + case mgr.StateTypeUndefined: + return notifications.Info + case mgr.StateTypeHint: + return notifications.Info + case mgr.StateTypeWarning: + return notifications.Warning + case mgr.StateTypeError: + return notifications.Error + default: + return notifications.Info + } +} diff --git a/service/status/provider.go b/service/status/provider.go deleted file mode 100644 index a8707e3b4..000000000 --- a/service/status/provider.go +++ /dev/null @@ -1,49 +0,0 @@ -package status - -import ( - "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/runtime" - "github.com/safing/portmaster/service/netenv" -) - -var pushUpdate runtime.PushFunc - -func setupRuntimeProvider() (err error) { - // register the system status getter - statusProvider := runtime.SimpleValueGetterFunc(func(_ string) ([]record.Record, error) { - return []record.Record{buildSystemStatus()}, nil - }) - pushUpdate, err = runtime.Register("system/status", statusProvider) - if err != nil { - return err - } - - return nil -} - -// buildSystemStatus build a new system status record. 
-func buildSystemStatus() *SystemStatusRecord { - status := &SystemStatusRecord{ - CaptivePortal: netenv.GetCaptivePortal(), - OnlineStatus: netenv.GetOnlineStatus(), - } - - status.CreateMeta() - status.SetKey("runtime:system/status") - - return status -} - -// pushSystemStatus pushes a new system status via -// the runtime database. -func pushSystemStatus() { - if pushUpdate == nil { - return - } - - record := buildSystemStatus() - record.Lock() - defer record.Unlock() - - pushUpdate(record) -} diff --git a/service/status/records.go b/service/status/records.go deleted file mode 100644 index a094bdb5c..000000000 --- a/service/status/records.go +++ /dev/null @@ -1,23 +0,0 @@ -package status - -import ( - "sync" - - "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/service/netenv" -) - -// SystemStatusRecord describes the overall status of the Portmaster. -// It's a read-only record exposed via runtime:system/status. -type SystemStatusRecord struct { - record.Base - sync.Mutex - - // OnlineStatus holds the current online status as - // seen by the netenv package. - OnlineStatus netenv.OnlineStatus - // CaptivePortal holds all information about the captive - // portal of the network the portmaster is currently - // connected to, if any. - CaptivePortal *netenv.CaptivePortal -} diff --git a/service/status/status.go b/service/status/status.go new file mode 100644 index 000000000..d00168dfa --- /dev/null +++ b/service/status/status.go @@ -0,0 +1,125 @@ +package status + +import ( + "slices" + "strings" + "sync" + + "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/runtime" + "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/netenv" +) + +// SystemStatusRecord describes the overall status of the Portmaster. +// It's a read-only record exposed via runtime:system/status. +type SystemStatusRecord struct { + record.Base + sync.Mutex + + // OnlineStatus holds the current online status as + // seen by the netenv package. + OnlineStatus netenv.OnlineStatus + // CaptivePortal holds all information about the captive + // portal of the network the portmaster is currently + // connected to, if any. + CaptivePortal *netenv.CaptivePortal + + Modules []mgr.StateUpdate + WorstState struct { + Module string + mgr.State + } +} + +func (s *Status) handleModuleStatusUpdate(_ *mgr.WorkerCtx, update mgr.StateUpdate) (cancel bool, err error) { + s.statesLock.Lock() + defer s.statesLock.Unlock() + + s.mgr.Error("received module state update", "state update", update) + s.states[update.Module] = update + s.deriveNotificationsFromStateUpdate(update) + s.triggerPublishStatus() + + return false, nil +} + +func (s *Status) triggerPublishStatus() { + select { + case s.triggerUpdate <- struct{}{}: + default: + } +} + +func (s *Status) statusPublisher(w *mgr.WorkerCtx) error { + for { + select { + case <-w.Done(): + return nil + case <-s.triggerUpdate: + s.publishSystemStatus() + } + } +} + +func (s *Status) setupRuntimeProvider() (err error) { + // register the system status getter + statusProvider := runtime.SimpleValueGetterFunc(func(_ string) ([]record.Record, error) { + return []record.Record{s.buildSystemStatus()}, nil + }) + s.publishUpdate, err = runtime.Register("system/status", statusProvider) + if err != nil { + return err + } + + return nil +} + +// buildSystemStatus build a new system status record. 
+func (s *Status) buildSystemStatus() *SystemStatusRecord { + s.statesLock.Lock() + defer s.statesLock.Unlock() + + status := &SystemStatusRecord{ + CaptivePortal: netenv.GetCaptivePortal(), + OnlineStatus: netenv.GetOnlineStatus(), + Modules: make([]mgr.StateUpdate, 0, len(s.states)), + } + for _, v := range s.states { + // Deep copy state. + newStateUpdate := v + newStateUpdate.States = make([]mgr.State, len(v.States)) + copy(newStateUpdate.States, v.States) + status.Modules = append(status.Modules, newStateUpdate) + + // Check if state is worst so far. + for _, state := range newStateUpdate.States { + if state.Type.Severity() > status.WorstState.Type.Severity() { + s.mgr.Error("new worst state", "state", state) + status.WorstState.State = state + status.WorstState.Module = newStateUpdate.Module + } + } + } + slices.SortFunc(status.Modules, func(a, b mgr.StateUpdate) int { + return strings.Compare(a.Module, b.Module) + }) + + status.CreateMeta() + status.SetKey("runtime:system/status") + return status +} + +// publishSystemStatus pushes a new system status via +// the runtime database. +func (s *Status) publishSystemStatus() { + if s.publishUpdate == nil { + return + } + + record := s.buildSystemStatus() + record.Lock() + defer record.Unlock() + + s.publishUpdate(record) +} From 51ec5cd5eeb35d6476d77abd4efbddc8f38af1e7 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jul 2024 16:39:33 +0200 Subject: [PATCH 29/56] Fix starting issues --- base/database/dbmodule/db.go | 10 +++++----- base/metrics/module.go | 2 +- service/core/base/module.go | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/base/database/dbmodule/db.go b/base/database/dbmodule/db.go index c29207167..99991c88d 100644 --- a/base/database/dbmodule/db.go +++ b/base/database/dbmodule/db.go @@ -46,11 +46,6 @@ func prep() error { } func start() error { - err := database.Initialize(databaseStructureRoot) - if err != nil { - return err - } - startMaintenanceTasks() return nil } @@ -78,6 +73,11 @@ func New(instance instance) (*DBModule, error) { instance: instance, } + err := database.Initialize(databaseStructureRoot) + if err != nil { + return nil, err + } + return module, nil } diff --git a/base/metrics/module.go b/base/metrics/module.go index ccabb4236..c9315f9d5 100644 --- a/base/metrics/module.go +++ b/base/metrics/module.go @@ -86,7 +86,7 @@ func start() error { } if pushOption() != "" { - _ = module.mgr.Do("metric pusher", metricsWriter) + module.mgr.Go("metric pusher", metricsWriter) } return nil diff --git a/service/core/base/module.go b/service/core/base/module.go index 57df37322..848ea97e8 100644 --- a/service/core/base/module.go +++ b/service/core/base/module.go @@ -21,11 +21,6 @@ func (b *Base) Manager() *mgr.Manager { func (b *Base) Start() error { startProfiling() - - if err := registerDatabases(); err != nil { - return err - } - registerLogCleaner() return nil @@ -50,6 +45,11 @@ func New(instance instance) (*Base, error) { mgr: m, instance: instance, } + + if err := registerDatabases(); err != nil { + return nil, err + } + return module, nil } From 8d8ce0b303ec8ea2d0b621b887e47a9399208c2c Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jul 2024 16:40:19 +0200 Subject: [PATCH 30/56] Remove pilot widget and update security lock to new status data --- desktop/angular/src/app/app.module.ts | 2 - .../src/app/layout/side-dash/side-dash.html | 2 +- .../angular/src/app/services/status.types.ts | 61 ++--- .../src/app/shared/config/subsystems.ts | 241 +++--------------- 
.../app/shared/security-lock/security-lock.ts | 41 +-- .../src/app/shared/status-pilot/index.ts | 1 - .../app/shared/status-pilot/pilot-widget.html | 57 ----- .../app/shared/status-pilot/pilot-widget.scss | 208 --------------- .../app/shared/status-pilot/pilot-widget.ts | 115 --------- 9 files changed, 65 insertions(+), 663 deletions(-) delete mode 100644 desktop/angular/src/app/shared/status-pilot/index.ts delete mode 100644 desktop/angular/src/app/shared/status-pilot/pilot-widget.html delete mode 100644 desktop/angular/src/app/shared/status-pilot/pilot-widget.scss delete mode 100644 desktop/angular/src/app/shared/status-pilot/pilot-widget.ts diff --git a/desktop/angular/src/app/app.module.ts b/desktop/angular/src/app/app.module.ts index c90aaec5f..187022d82 100644 --- a/desktop/angular/src/app/app.module.ts +++ b/desktop/angular/src/app/app.module.ts @@ -64,7 +64,6 @@ import { SecurityLockComponent } from './shared/security-lock'; import { SPNAccountDetailsComponent } from './shared/spn-account-details'; import { SPNLoginComponent } from './shared/spn-login'; import { SPNStatusComponent } from './shared/spn-status'; -import { PilotWidgetComponent } from './shared/status-pilot'; import { PlaceholderComponent } from './shared/text-placeholder'; import { DashboardWidgetComponent } from './pages/dashboard/dashboard-widget/dashboard-widget.component'; import { MergeProfileDialogComponent } from './pages/app-view/merge-profile-dialog/merge-profile-dialog.component'; @@ -133,7 +132,6 @@ const localeConfig = { MonitorPageComponent, SideDashComponent, NavigationComponent, - PilotWidgetComponent, NotificationListComponent, PromptListComponent, FuzzySearchPipe, diff --git a/desktop/angular/src/app/layout/side-dash/side-dash.html b/desktop/angular/src/app/layout/side-dash/side-dash.html index 81e8b15f4..a4f8e073e 100644 --- a/desktop/angular/src/app/layout/side-dash/side-dash.html +++ b/desktop/angular/src/app/layout/side-dash/side-dash.html @@ -1,5 +1,5 @@
- +
diff --git a/desktop/angular/src/app/services/status.types.ts b/desktop/angular/src/app/services/status.types.ts index f51883667..52013c907 100644 --- a/desktop/angular/src/app/services/status.types.ts +++ b/desktop/angular/src/app/services/status.types.ts @@ -6,22 +6,6 @@ export interface CaptivePortal { Domain: string; } -export enum ModuleStatus { - Off = 0, - Error = 1, - Warning = 2, - Operational = 3 -} - -/** - * Returns a string represetnation of the module status. - * - * @param stat The module status to translate - */ -export function getModuleStatusString(stat: ModuleStatus): string { - return getEnumKey(ModuleStatus, stat) -} - export enum OnlineStatus { Unknown = 0, Offline = 1, @@ -40,55 +24,46 @@ export function getOnlineStatusString(stat: OnlineStatus): string { return getEnumKey(OnlineStatus, stat) } -export interface Threat { - ID: string; - Name: string; - Description: string; - AdditionalData: T; - MitigationLevel: SecurityLevel; - Started: number; - Ended: number; -} - export interface CoreStatus extends Record { - ActiveSecurityLevel: SecurityLevel; - SelectedSecurityLevel: SecurityLevel; - ThreatMitigationLevel: SecurityLevel; OnlineStatus: OnlineStatus; - Threats: Threat[]; CaptivePortal: CaptivePortal; + // Modules: []ModuleState; // TODO: Do we need all modules? + WorstState: { + Module: string, + ID: string, + Name: string, + Message: string, + Type: ModuleStateType, + // Time: time.Time, // TODO: How do we best use Go's time.Time? + Data: any + } } -export enum FailureStatus { - Operational = 0, - Hint = 1, - Warning = 2, - Error = 3 +export enum ModuleStateType { + Undefined = "", + Hint = "hint", + Warning = "warning", + Error = "error" } /** * Returns a string representation of a failure status value. * - * @param stat The failure status value. + * @param stateType The module state type value. 
*/ -export function getFailureStatusString(stat: FailureStatus): string { - return getEnumKey(FailureStatus, stat) +export function getModuleStateString(stateType: ModuleStateType): string { + return getEnumKey(ModuleStateType, stateType) } export interface Module { Enabled: boolean; - FailureID: string; - FailureMsg: string; - FailureStatus: FailureStatus; Name: string; - Status: ModuleStatus; } export interface Subsystem extends Record { ConfigKeySpace: string; Description: string; ExpertiseLevel: string; - FailureStatus: FailureStatus; ID: string; Modules: Module[]; Name: string; diff --git a/desktop/angular/src/app/shared/config/subsystems.ts b/desktop/angular/src/app/shared/config/subsystems.ts index 2071c8074..97bdfab18 100644 --- a/desktop/angular/src/app/shared/config/subsystems.ts +++ b/desktop/angular/src/app/shared/config/subsystems.ts @@ -1,5 +1,5 @@ import { ExpertiseLevelNumber } from "@safing/portmaster-api"; -import { ModuleStatus, Subsystem } from "src/app/services/status.types"; +import { Subsystem } from "src/app/services/status.types"; export interface SubsystemWithExpertise extends Subsystem { minimumExpertise: ExpertiseLevelNumber; @@ -18,70 +18,37 @@ export var subsystems : SubsystemWithExpertise[] = [ Modules: [ { Name: "core", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "subsystems", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "runtime", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "status", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "ui", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "compat", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "broadcasts", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "sync", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true } ], - FailureStatus: 0, ToggleOptionKey: "", ExpertiseLevel: "user", ReleaseLevel: 0, @@ -104,22 +71,13 @@ export var subsystems : SubsystemWithExpertise[] = [ Modules: [ { Name: "nameserver", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "resolver", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true } ], - FailureStatus: 0, ToggleOptionKey: "", ExpertiseLevel: "user", ReleaseLevel: 0, @@ -142,150 +100,77 @@ export var subsystems : SubsystemWithExpertise[] = [ Modules: [ { Name: "filter", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "interception", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "base", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "database", - Enabled: true, - Status: ModuleStatus.Operational, - 
FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "config", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "rng", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "metrics", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "api", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "updates", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "network", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "netenv", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "processes", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "profiles", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "notifications", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "intel", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "geoip", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "filterlists", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true }, { Name: "customlists", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true } ], - FailureStatus: 0, ToggleOptionKey: "", ExpertiseLevel: "user", ReleaseLevel: 0, @@ -308,14 +193,9 @@ export var subsystems : SubsystemWithExpertise[] = [ Modules: [ { Name: "netquery", - Enabled: true, - Status: ModuleStatus.Operational, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: true } ], - FailureStatus: 0, ToggleOptionKey: "", ExpertiseLevel: "user", ReleaseLevel: 0, @@ -338,86 +218,45 @@ export var subsystems : SubsystemWithExpertise[] = [ Modules: [ { Name: "captain", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false }, { Name: "terminal", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false }, { Name: "cabin", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false }, { Name: "ships", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false }, { Name: "docks", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false }, { Name: "access", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false }, { Name: "crew", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false }, { Name: "navigator", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" 
+ Enabled: false }, { Name: "sluice", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false }, { Name: "patrol", - Enabled: false, - Status: 2, - FailureStatus: 0, - FailureID: "", - FailureMsg: "" + Enabled: false } ], - FailureStatus: 0, ToggleOptionKey: "spn/enable", ExpertiseLevel: "user", ReleaseLevel: 0, diff --git a/desktop/angular/src/app/shared/security-lock/security-lock.ts b/desktop/angular/src/app/shared/security-lock/security-lock.ts index 7b6922c39..1485cc119 100644 --- a/desktop/angular/src/app/shared/security-lock/security-lock.ts +++ b/desktop/angular/src/app/shared/security-lock/security-lock.ts @@ -1,7 +1,7 @@ import { ChangeDetectionStrategy, ChangeDetectorRef, Component, DestroyRef, Input, OnInit, inject } from "@angular/core"; import { SecurityLevel } from "@safing/portmaster-api"; import { combineLatest } from "rxjs"; -import { FailureStatus, StatusService, Subsystem } from "src/app/services"; +import { StatusService, ModuleStateType } from "src/app/services"; import { fadeInAnimation, fadeOutAnimation } from "../animations"; interface SecurityOption { @@ -36,14 +36,7 @@ export class SecurityLockComponent implements OnInit { ) { } ngOnInit(): void { - combineLatest([ - this.statusService.status$, - this.statusService.watchSubsystems() - ]) - .subscribe(([status, subsystems]) => { - const activeLevel = status.ActiveSecurityLevel; - const suggestedLevel = status.ThreatMitigationLevel; - + this.statusService.status$.subscribe(status => { // By default the lock is green and we are "Secure" this.lockLevel = { level: SecurityLevel.Normal, @@ -51,28 +44,16 @@ export class SecurityLockComponent implements OnInit { displayText: 'Secure', } - // Find the highest failure-status reported by any module - // of any subsystem. - const failureStatus = subsystems.reduce((value: FailureStatus, system: Subsystem) => { - if (system.FailureStatus != 0) { - console.log(system); - } - return system.FailureStatus > value - ? system.FailureStatus - : value; - }, FailureStatus.Operational) - - // update the failure level depending on the highest - // failure status. - switch (failureStatus) { - case FailureStatus.Warning: + // update the shield depending on the worst state. + switch (status.WorstState.Type) { + case ModuleStateType.Warning: this.lockLevel = { level: SecurityLevel.High, class: 'text-yellow-300', displayText: 'Warning' } break; - case FailureStatus.Error: + case ModuleStateType.Error: this.lockLevel = { level: SecurityLevel.Extreme, class: 'text-red-300', @@ -81,16 +62,6 @@ export class SecurityLockComponent implements OnInit { break; } - // if the auto-pilot would suggest a higher (mitigation) level - // we are always Insecure - if (activeLevel < suggestedLevel) { - this.lockLevel = { - level: SecurityLevel.High, - class: 'high', - displayText: 'Insecure' - } - } - this.cdr.markForCheck(); }); } diff --git a/desktop/angular/src/app/shared/status-pilot/index.ts b/desktop/angular/src/app/shared/status-pilot/index.ts deleted file mode 100644 index 1ec75e5b3..000000000 --- a/desktop/angular/src/app/shared/status-pilot/index.ts +++ /dev/null @@ -1 +0,0 @@ -export { StatusPilotComponent as PilotWidgetComponent } from "./pilot-widget"; diff --git a/desktop/angular/src/app/shared/status-pilot/pilot-widget.html b/desktop/angular/src/app/shared/status-pilot/pilot-widget.html deleted file mode 100644 index 52e41fbba..000000000 --- a/desktop/angular/src/app/shared/status-pilot/pilot-widget.html +++ /dev/null @@ -1,57 +0,0 @@ - - - -
- [deleted pilot-widget.html body: the markup is not recoverable from this extraction; the visible template text was {{ activeLevelText }}, the "Auto Detect" / "Manual" mode labels, and the per-option {{opt.displayText}} / {{opt.subText || ''}} bindings]
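The security-lock change above drops the per-subsystem FailureStatus aggregation and derives the shield directly from the worst module state reported by the service. A minimal standalone sketch of that mapping, assuming the ModuleStateType string values and the error-case display text (mapWorstStateToShield, ShieldState and the enum values are illustrative, not part of this patch):

import { SecurityLevel } from "@safing/portmaster-api";

// Assumed string values; the real enum is imported from src/app/services in the patch.
enum ModuleStateType {
  Undefined = "",
  Hint = "hint",
  Warning = "warning",
  Error = "error",
}

// Mirrors the component's SecurityOption shape (minus the optional subText).
interface ShieldState {
  level: SecurityLevel;
  class: string;
  displayText: string;
}

// Maps the worst reported module state to the shield shown in the UI.
function mapWorstStateToShield(worst: ModuleStateType): ShieldState {
  switch (worst) {
    case ModuleStateType.Warning:
      return { level: SecurityLevel.High, class: "text-yellow-300", displayText: "Warning" };
    case ModuleStateType.Error:
      // Display text assumed; the patch only shows the level and class for this case.
      return { level: SecurityLevel.Extreme, class: "text-red-300", displayText: "Error" };
    default:
      // Hints and undefined states keep the default "Secure" shield (class assumed).
      return { level: SecurityLevel.Normal, class: "text-green-300", displayText: "Secure" };
  }
}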
diff --git a/desktop/angular/src/app/shared/status-pilot/pilot-widget.scss b/desktop/angular/src/app/shared/status-pilot/pilot-widget.scss deleted file mode 100644 index 3f1bcae72..000000000 --- a/desktop/angular/src/app/shared/status-pilot/pilot-widget.scss +++ /dev/null @@ -1,208 +0,0 @@ -:host { - overflow: visible; - position: relative; - display: flex; - justify-content: space-between; - background: none; - user-select: none; - align-items: center; - justify-content: space-evenly; - flex-direction: column; - - - @keyframes shield-pulse { - 0% { - transform: scale(.62); - opacity: 1; - } - - 100% { - transform: scale(1.1); - opacity: 0; - } - } - - @keyframes pulse-opacity { - 0% { - opacity: 0.1; - } - - 100% { - opacity: 1; - } - } -} - -.spn-status { - background-color: var(--info-blue); - border-radius: 100%; - display: flex; - align-items: center; - justify-content: center; - opacity: 1 !important; - padding: 0.2rem; - transform: scale(0.8); - position: absolute; - bottom: 42px; - right: 18px; - - &.connected { - background-color: theme('colors.info.blue'); - } - - &.connecting, - &.failed { - background-color: theme('colors.info.gray'); - } - - svg { - stroke: white; - } -} - -::ng-deep { - - .network-rating-level-list { - @apply p-3 rounded; - - flex-grow: 1; - - label { - opacity: 0.6; - font-size: 0.75rem; - font-weight: 500; - } - - div.rate-header { - display: flex; - justify-content: space-between; - align-items: center; - padding: 0 0 0.3rem 0; - margin-right: 0.11rem; - - .auto-detect { - height: 5px; - width: 5px; - margin-right: 10px; - margin-bottom: 1px; - background-color: #4995f3; - border-radius: 50%; - display: inline-block; - } - } - - &:not(.auto-pilot) { - div.level.selected { - div { - background-color: #292929; - } - - &:after { - transition: none; - opacity: 0 !important; - } - } - } - - div.level { - position: relative; - padding: 2px; - margin-top: 0.155rem; - cursor: pointer; - overflow: hidden; - z-index: 1; - - fa-icon[icon*="question-circle"] { - float: right; - } - - &:after { - transition: all cubic-bezier(0.19, 1, 0.82, 1) .2s; - @apply rounded; - content: ""; - filter: saturate(1.3); - background-image: linear-gradient(90deg, #226ab79f 0%, rgba(2, 0, 36, 0) 45%); - transform: translateX(100%); - position: absolute; - top: 0; - left: 0; - right: 0; - bottom: 0; - z-index: -1; - opacity: 0; - } - - div { - background-color: #202020; - border-radius: 2px; - padding: 9px 17px 10px 18px; - display: block; - opacity: 0.55; - - span { - font-size: 0.725rem; - font-weight: 400; - } - - .situation { - @apply text-tertiary; - @apply ml-2; - font-size: 0.6rem; - font-weight: 600; - } - - svg.help { - width: 0.95rem; - float: right; - padding: 0; - margin: 0; - margin-top: 1.5px; - - .inner { - stroke: var(--text-secondary); - } - - &:hover, - &:active { - .inner { - stroke: var(--text-primary); - } - } - } - } - - &.selected { - div { - background-color: #292929; - opacity: 1; - } - } - - &.selected, - &.suggested { - &:after { - transform: translateX(0%); - opacity: 1; - } - - } - - &.suggested { - &:after { - animation: pulse-opacity 1s ease-in-out infinite alternate; - } - } - - &:hover, - &:active { - div { - opacity: 1; - - span { - opacity: 1; - } - } - } - } - } -} diff --git a/desktop/angular/src/app/shared/status-pilot/pilot-widget.ts b/desktop/angular/src/app/shared/status-pilot/pilot-widget.ts deleted file mode 100644 index 4fa01dd64..000000000 --- a/desktop/angular/src/app/shared/status-pilot/pilot-widget.ts +++ /dev/null @@ -1,115 +0,0 @@ -import 
{ ChangeDetectionStrategy, ChangeDetectorRef, Component, OnInit } from '@angular/core'; -import { ConfigService, SecurityLevel } from '@safing/portmaster-api'; -import { combineLatest } from 'rxjs'; -import { FailureStatus, StatusService, Subsystem } from 'src/app/services'; - -interface SecurityOption { - level: SecurityLevel; - displayText: string; - class: string; - subText?: string; -} - -@Component({ - selector: 'app-status-pilot', - templateUrl: './pilot-widget.html', - styleUrls: [ - './pilot-widget.scss' - ], - changeDetection: ChangeDetectionStrategy.OnPush, -}) -export class StatusPilotComponent implements OnInit { - activeLevel: SecurityLevel = SecurityLevel.Off; - selectedLevel: SecurityLevel = SecurityLevel.Off; - suggestedLevel: SecurityLevel = SecurityLevel.Off; - activeOption: SecurityOption | null = null; - selectedOption: SecurityOption | null = null; - - mode: 'auto' | 'manual' = 'auto'; - - get activeLevelText() { - return this.options.find(opt => opt.level === this.activeLevel)?.displayText || ''; - } - - readonly options: SecurityOption[] = [ - { - level: SecurityLevel.Normal, - displayText: 'Trusted', - class: 'low', - subText: 'Home Network' - }, - { - level: SecurityLevel.High, - displayText: 'Untrusted', - class: 'medium', - subText: 'Public Network' - }, - { - level: SecurityLevel.Extreme, - displayText: 'Danger', - class: 'high', - subText: 'Hacked Network' - }, - ]; - - get networkRatingEnabled$() { return this.configService.networkRatingEnabled$ } - - constructor( - private statusService: StatusService, - private changeDetectorRef: ChangeDetectorRef, - private configService: ConfigService, - ) { } - - ngOnInit() { - - combineLatest([ - this.statusService.status$, - this.statusService.watchSubsystems() - ]) - .subscribe(([status, subsystems]) => { - this.activeLevel = status.ActiveSecurityLevel; - this.selectedLevel = status.SelectedSecurityLevel; - this.suggestedLevel = status.ThreatMitigationLevel; - - if (this.selectedLevel === SecurityLevel.Off) { - this.mode = 'auto'; - } else { - this.mode = 'manual'; - } - - this.selectedOption = this.options.find(opt => opt.level === this.selectedLevel) || null; - this.activeOption = this.options.find(opt => opt.level === this.activeLevel) || null; - - // Find the highest failure-status reported by any module - // of any subsystem. - const failureStatus = subsystems.reduce((value: FailureStatus, system: Subsystem) => { - if (system.FailureStatus != 0) { - console.log(system); - } - return system.FailureStatus > value - ? 
system.FailureStatus - : value; - }, FailureStatus.Operational) - - this.changeDetectorRef.markForCheck(); - }); - } - - updateMode(mode: 'auto' | 'manual') { - this.mode = mode; - - if (mode === 'auto') { - this.selectLevel(SecurityLevel.Off); - } else { - this.selectLevel(this.activeLevel); - } - } - - selectLevel(level: SecurityLevel) { - if (this.mode === 'auto' && level !== SecurityLevel.Off) { - this.mode = 'manual'; - } - - this.statusService.selectLevel(level).subscribe(); - } -} From e55965aa5a9c99552d8d01c6dd946323cb5b2dcd Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jul 2024 16:42:25 +0200 Subject: [PATCH 31/56] Remove debug logs --- service/status/module.go | 1 - service/status/status.go | 1 - 2 files changed, 2 deletions(-) diff --git a/service/status/module.go b/service/status/module.go index 6d5edfb3e..1846c4e8f 100644 --- a/service/status/module.go +++ b/service/status/module.go @@ -55,7 +55,6 @@ func (s *Status) Start() error { s.instance.AddStatusCallback("status update", s.handleModuleStatusUpdate) // Get initial states. for _, stateUpdate := range s.instance.GetStatus() { - s.mgr.Info("status update", stateUpdate) s.states[stateUpdate.Module] = stateUpdate s.deriveNotificationsFromStateUpdate(stateUpdate) } diff --git a/service/status/status.go b/service/status/status.go index d00168dfa..7a7ce8896 100644 --- a/service/status/status.go +++ b/service/status/status.go @@ -36,7 +36,6 @@ func (s *Status) handleModuleStatusUpdate(_ *mgr.WorkerCtx, update mgr.StateUpda s.statesLock.Lock() defer s.statesLock.Unlock() - s.mgr.Error("received module state update", "state update", update) s.states[update.Module] = update s.deriveNotificationsFromStateUpdate(update) s.triggerPublishStatus() From 9b54f883406b8d87ee84aa3e738fd367b774b154 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jul 2024 17:07:09 +0200 Subject: [PATCH 32/56] Improve http server shutdown --- base/api/router.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/base/api/router.go b/base/api/router.go index 4ec77758d..a5fdff957 100644 --- a/base/api/router.go +++ b/base/api/router.go @@ -83,17 +83,20 @@ func stopServer() error { } // Serve starts serving the API endpoint. 
-func serverManager(_ *mgr.WorkerCtx) error { +func serverManager(ctx *mgr.WorkerCtx) error { // start serving log.Infof("api: starting to listen on %s", server.Addr) backoffDuration := 10 * time.Second for { - // always returns an error - err := module.mgr.Do("http endpoint", func(ctx *mgr.WorkerCtx) error { - return server.ListenAndServe() + err := module.mgr.Do("http server", func(ctx *mgr.WorkerCtx) error { + err := server.ListenAndServe() + // return on shutdown error + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err }) - // return on shutdown error - if errors.Is(err, http.ErrServerClosed) { + if err == nil { return nil } // log error and restart From 84936ed1d044c04ea4b270f86acaeb8c082beb4b Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jul 2024 17:07:30 +0200 Subject: [PATCH 33/56] Add workaround for cleanly shutting down firewall+netquery --- service/firewall/module.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/service/firewall/module.go b/service/firewall/module.go index 2be348fb3..1f716cd08 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" "sync/atomic" + "time" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/log" @@ -57,6 +58,12 @@ func (f *Firewall) Start() error { } func (f *Firewall) Stop() error { + // Cancel all workers and give them a little time. + // The bandwidth updater can crash the sqlite DB for some reason. + // TODO: Investigate. + f.mgr.Cancel() + time.Sleep(100 * time.Millisecond) + return stop() } From b93ed0908368ee8f272b7b6c4908ecf3f9e03f73 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 16 Jul 2024 17:07:41 +0200 Subject: [PATCH 34/56] Improve logging --- service/mgr/module.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/service/mgr/module.go b/service/mgr/module.go index d56135dd5..45d917949 100644 --- a/service/mgr/module.go +++ b/service/mgr/module.go @@ -104,7 +104,7 @@ func (g *Group) Start() error { m.mgr.Info("starting") startTime := time.Now() - err := m.mgr.Do(m.mgr.name+" Start", func(_ *WorkerCtx) error { + err := m.mgr.Do("start module "+m.mgr.name, func(_ *WorkerCtx) error { return m.module.Start() }) if err != nil { @@ -116,7 +116,7 @@ func (g *Group) Start() error { return fmt.Errorf("failed to start %s: %w", makeModuleName(m.module), err) } duration := time.Since(startTime) - m.mgr.Info("started " + duration.String()) + m.mgr.Info("started", "time", duration.String()) } g.state.Store(groupStateRunning) return nil @@ -142,7 +142,7 @@ func (g *Group) stopFrom(index int) (ok bool) { for i := index; i >= 0; i-- { m := g.modules[i] - err := m.mgr.Do(m.mgr.name+" Stop", func(_ *WorkerCtx) error { + err := m.mgr.Do("stop module "+m.mgr.name, func(_ *WorkerCtx) error { return m.module.Stop() }) if err != nil { From d2c6ab5834f0e1208bc743397181d5955f73428b Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 18 Jul 2024 10:58:56 +0200 Subject: [PATCH 35/56] Add syncing states with notifications for new module system --- base/notifications/module-mirror.go | 164 +++++++++------------------- base/notifications/notification.go | 12 +- service/compat/notify.go | 2 +- service/mgr/states.go | 50 +++++++-- service/nameserver/module.go | 13 +-- service/netenv/main.go | 2 - service/resolver/main.go | 3 +- service/status/notifications.go | 63 +++++++---- service/updates/notify.go | 2 +- spn/captain/client.go | 36 ++---- 10 files changed, 160 insertions(+), 187 deletions(-) diff --git 
a/base/notifications/module-mirror.go b/base/notifications/module-mirror.go index e1614fa0d..41ef350f8 100644 --- a/base/notifications/module-mirror.go +++ b/base/notifications/module-mirror.go @@ -1,116 +1,58 @@ package notifications import ( -// "github.com/safing/portbase/modules" -// "github.com/safing/portmaster/base/log" -// "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" ) -// AttachToModule attaches the notification to a module and changes to the -// notification will be reflected on the module failure status. -// func (n *Notification) AttachToState(state *mgr.StateMgr) { -// if state == nil { -// log.Warningf("notifications: invalid usage: cannot attach %s to nil module", n.EventID) -// return -// } - -// n.lock.Lock() -// defer n.lock.Unlock() - -// if n.State != Active { -// log.Warningf("notifications: cannot attach module to inactive notification %s", n.EventID) -// return -// } -// if n.belongsTo != nil { -// log.Warningf("notifications: cannot override attached module for notification %s", n.EventID) -// return -// } - -// // Attach module. -// n.belongsTo = state - -// // Set module failure status. -// switch n.Type { //nolint:exhaustive -// case Info: -// m.Hint(n.EventID, n.Title, n.Message) -// case Warning: -// m.Warning(n.EventID, n.Title, n.Message) -// case Error: -// m.Error(n.EventID, n.Title, n.Message) -// default: -// log.Warningf("notifications: incompatible type for attaching to module in notification %s", n.EventID) -// m.Error(n.EventID, n.Title, n.Message+" [incompatible notification type]") -// } -// } - -// // resolveModuleFailure removes the notification from the module failure status. -// func (n *Notification) resolveModuleFailure() { -// if n.belongsTo != nil { -// // Resolve failure in attached module. -// n.belongsTo.Resolve(n.EventID) - -// // Reset attachment in order to mitigate duplicate failure resolving. -// // Re-attachment is prevented by the state check when attaching. -// n.belongsTo = nil -// } -// } - -// func init() { -// modules.SetFailureUpdateNotifyFunc(mirrorModuleStatus) -// } - -// func mirrorModuleStatus(moduleFailure uint8, id, title, msg string) { -// // Ignore "resolve all" requests. -// if id == "" { -// return -// } - -// // Get notification from storage. -// n, ok := getNotification(id) -// if ok { -// // The notification already exists. - -// // Check if we should delete it. -// if moduleFailure == modules.FailureNone && !n.Meta().IsDeleted() { - -// // Remove belongsTo, as the deletion was already triggered by the module itself. -// n.Lock() -// n.belongsTo = nil -// n.Unlock() - -// n.Delete() -// } - -// return -// } - -// // A notification for the given ID does not yet exists, create it. -// n = &Notification{ -// EventID: id, -// Title: title, -// Message: msg, -// AvailableActions: []*Action{ -// { -// Text: "Get Help", -// Type: ActionTypeOpenURL, -// Payload: "https://safing.io/support/", -// }, -// }, -// } - -// switch moduleFailure { -// case modules.FailureNone: -// return -// case modules.FailureHint: -// n.Type = Info -// n.AvailableActions = nil -// case modules.FailureWarning: -// n.Type = Warning -// n.ShowOnSystem = true -// case modules.FailureError: -// n.Type = Error -// n.ShowOnSystem = true -// } - -// Notify(n) -// } +// SyncWithState syncs the notification to a state in the given state mgr. +// The state will be removed when the notification is removed. 
+func (n *Notification) SyncWithState(state *mgr.StateMgr) { + if state == nil { + log.Warningf("notifications: invalid usage: cannot attach %s to nil module", n.EventID) + return + } + + n.lock.Lock() + defer n.lock.Unlock() + + if n.Meta().IsDeleted() { + log.Warningf("notifications: cannot attach module to deleted notification %s", n.EventID) + return + } + if n.State != Active { + log.Warningf("notifications: cannot attach module to inactive notification %s", n.EventID) + return + } + if n.belongsTo != nil { + log.Warningf("notifications: cannot override attached module for notification %s", n.EventID) + return + } + + // Attach module. + n.belongsTo = state + + // Create state with same ID. + state.Add(mgr.State{ + ID: n.EventID, + Name: n.Title, + Message: n.Message, + Type: notifTypeToStateType(n.Type), + Data: n.EventData, + }) +} + +func notifTypeToStateType(notifType Type) mgr.StateType { + switch notifType { + case Info: + return mgr.StateTypeHint + case Warning: + return mgr.StateTypeWarning + case Prompt: + return mgr.StateTypeUndefined + case Error: + return mgr.StateTypeError + default: + return mgr.StateTypeUndefined + } +} diff --git a/base/notifications/notification.go b/base/notifications/notification.go index 68526ab17..6db6cef47 100644 --- a/base/notifications/notification.go +++ b/base/notifications/notification.go @@ -101,7 +101,7 @@ type Notification struct { //nolint:maligned // belongsTo holds the state this notification belongs to. The notification // lifecycle will be mirrored to the specified failure status. - // belongsTo *mgr.StateMgr + belongsTo *mgr.StateMgr lock sync.Mutex actionFunction NotificationActionFn // call function to process action @@ -425,6 +425,11 @@ func (n *Notification) delete(pushUpdate bool) { n.lock.Lock() defer n.lock.Unlock() + // Check if notification is already deleted. + if n.Meta().IsDeleted() { + return + } + // Save ID for deletion id = n.EventID @@ -442,7 +447,10 @@ func (n *Notification) delete(pushUpdate bool) { dbController.PushUpdate(n) } - // n.resolveModuleFailure() + // Remove the connected state. + if n.belongsTo != nil { + n.belongsTo.Remove(n.EventID) + } } // Expired notifies the caller when the notification has expired. diff --git a/service/compat/notify.go b/service/compat/notify.go index ce2a949c7..72932aeb3 100644 --- a/service/compat/notify.go +++ b/service/compat/notify.go @@ -137,7 +137,7 @@ func (issue *systemIssue) notify(err error) { notifications.Notify(n) systemIssueNotification = n - // n.AttachToModule(module) + n.SyncWithState(module.states) // Report the raw error as module error. // FIXME(vladimir): Is there a need for this kind of error reporting? diff --git a/service/mgr/states.go b/service/mgr/states.go index 5ecb456a5..a4228522d 100644 --- a/service/mgr/states.go +++ b/service/mgr/states.go @@ -18,12 +18,32 @@ type StateMgr struct { // State describes the state of a manager or module. type State struct { - ID string // Required. - Name string // Required. - Message string // Optional. - Type StateType // Optional. - Time time.Time // Optional, will be set to current time if not set. - Data any // Optional. + // ID is a program-unique ID. + // It must not only be unique within the StateMgr, but for the whole program, + // as it may be re-used with related systems. + // Required. + ID string + + // Name is the name of the state. + // This may also serve as a notification title. + // Required. + Name string + + // Message is a more detailed message about the state. + // Optional. 
+ Message string + + // Type defines the type of the state. + // Optional. + Type StateType + + // Time is the time when the state was created or the originating incident occured. + // Optional, will be set to current time if not set. + Time time.Time + + // Data can hold any additional data necessary for further processing of connected systems. + // Optional. + Data any } // StateType defines commonly used states. @@ -59,6 +79,7 @@ type StateUpdate struct { States []State } +// StatefulModule is used for interface checks on modules. type StatefulModule interface { States() *StateMgr } @@ -87,10 +108,10 @@ func (m *StateMgr) Add(s State) { } // Update or add state. - index := slices.IndexFunc[[]State, State](m.states, func(es State) bool { + index := slices.IndexFunc(m.states, func(es State) bool { return es.ID == s.ID }) - if index > 0 { + if index >= 0 { m.states[index] = s } else { m.states = append(m.states, s) @@ -104,11 +125,18 @@ func (m *StateMgr) Remove(id string) { m.statesLock.Lock() defer m.statesLock.Unlock() - slices.DeleteFunc[[]State, State](m.states, func(s State) bool { - return s.ID == id + var entryRemoved bool + m.states = slices.DeleteFunc(m.states, func(s State) bool { + if s.ID == id { + entryRemoved = true + return true + } + return false }) - m.statesEventMgr.Submit(m.export()) + if entryRemoved { + m.statesEventMgr.Submit(m.export()) + } } // Clear removes all states. diff --git a/service/nameserver/module.go b/service/nameserver/module.go index f2f3e4305..69fad7f28 100644 --- a/service/nameserver/module.go +++ b/service/nameserver/module.go @@ -179,7 +179,7 @@ func startListener(ip net.IP, port uint16, first bool) { } func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) { - // var n *notifications.Notification + var n *notifications.Notification // Create suffix for secondary listener var secondaryEventIDSuffix string @@ -208,7 +208,7 @@ func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) } // Notify user about conflicting service. - _ = notifications.Notify(¬ifications.Notification{ + n = notifications.Notify(¬ifications.Notification{ EventID: eventIDConflictingService + secondaryEventIDSuffix, Type: notifications.Error, Title: "Conflicting DNS Software", @@ -225,7 +225,7 @@ func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) }) } else { // If no conflict is found, report the error directly. - _ = notifications.Notify(¬ifications.Notification{ + n = notifications.Notify(¬ifications.Notification{ EventID: eventIDListenerFailed + secondaryEventIDSuffix, Type: notifications.Error, Title: "Secure DNS Error", @@ -238,10 +238,9 @@ func handleListenError(err error, ip net.IP, port uint16, primaryListener bool) } // Attach error to module, if primary listener. - // TODO(vladimir): is this needed? - // if primaryListener { - // n.AttachToModule(module) - // } + if primaryListener { + n.SyncWithState(module.states) + } } func stop() error { diff --git a/service/netenv/main.go b/service/netenv/main.go index 608b1a066..e1a681500 100644 --- a/service/netenv/main.go +++ b/service/netenv/main.go @@ -48,7 +48,6 @@ func (ne *NetEnv) Stop() error { } func prep() error { - // FIXME(vladimir): Does this need to be in the prep function. checkForIPv6Stack() if err := registerAPIEndpoints(); err != nil { @@ -59,7 +58,6 @@ func prep() error { return err } - // FIXME(vladimir): Does this need to be in the prep function. 
return prepLocation() } diff --git a/service/resolver/main.go b/service/resolver/main.go index 46c26ab02..69b7c51db 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -188,8 +188,7 @@ This notification will go away when Portmaster detects a working configured DNS notifications.Notify(n) failingResolverNotification = n - // TODO(vladimir): is this needed? - // n.AttachToModule(module) + n.SyncWithState(module.states) } func resetFailingResolversNotification() { diff --git a/service/status/notifications.go b/service/status/notifications.go index c5208241a..0e7deb896 100644 --- a/service/status/notifications.go +++ b/service/status/notifications.go @@ -20,16 +20,48 @@ func (s *Status) deriveNotificationsFromStateUpdate(update mgr.StateUpdate) { for _, state := range update.States { seenStateIDs[state.ID] = struct{}{} - _, ok := notifs[state.ID] - if !ok { - n := notifications.Notify(¬ifications.Notification{ - EventID: update.Module + ":" + state.ID, - Type: stateTypeToNotifType(state.Type), - Title: state.Name, - Message: state.Message, - }) + // Check if we already have a notification registered. + if _, ok := notifs[state.ID]; ok { + continue + } + + // Check if the notification was pre-created. + // If a matching notification is found, assign it. + n := notifications.Get(state.ID) + if n != nil { notifs[state.ID] = n + continue + } + + // Create a new notification. + n = ¬ifications.Notification{ + EventID: state.ID, + Title: state.Name, + Message: state.Message, + AvailableActions: []*notifications.Action{ + { + Text: "Get Help", + Type: notifications.ActionTypeOpenURL, + Payload: "https://safing.io/support/", + }, + }, + } + switch state.Type { + case mgr.StateTypeWarning: + n.Type = notifications.Warning + n.ShowOnSystem = true + case mgr.StateTypeError: + n.Type = notifications.Error + n.ShowOnSystem = true + case mgr.StateTypeHint, mgr.StateTypeUndefined: + fallthrough + default: + n.Type = notifications.Info + n.AvailableActions = nil } + + notifs[state.ID] = n + notifications.Notify(n) } // Remove notifications. @@ -40,18 +72,3 @@ func (s *Status) deriveNotificationsFromStateUpdate(update mgr.StateUpdate) { } } } - -func stateTypeToNotifType(stateType mgr.StateType) notifications.Type { - switch stateType { - case mgr.StateTypeUndefined: - return notifications.Info - case mgr.StateTypeHint: - return notifications.Info - case mgr.StateTypeWarning: - return notifications.Warning - case mgr.StateTypeError: - return notifications.Error - default: - return notifications.Info - } -} diff --git a/service/updates/notify.go b/service/updates/notify.go index 0cd97bfde..30b2bd32b 100644 --- a/service/updates/notify.go +++ b/service/updates/notify.go @@ -164,5 +164,5 @@ func notifyUpdateCheckFailed(force bool, err error) { ResultAction: "display", }, }, - ) // FIXME: add replacement for this .AttachToModule(module) + ).SyncWithState(module.states) } diff --git a/spn/captain/client.go b/spn/captain/client.go index 49bcf7828..4a2546004 100644 --- a/spn/captain/client.go +++ b/spn/captain/client.go @@ -228,9 +228,7 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { `Please restart Portmaster.`, // TODO: Add restart button. // TODO: Use special UI restart action in order to reload UI on restart. 
- ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) resetSPNStatus(StatusFailed, true) log.Errorf("spn/captain: client internal error: %s", err) return clientResultReconnect @@ -243,9 +241,7 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "SPN Login Required", `Please log in to access the SPN.`, spnLoginButton, - ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) resetSPNStatus(StatusFailed, true) log.Warningf("spn/captain: enabled but not logged in") return clientResultReconnect @@ -263,9 +259,7 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "spn:failed-to-update-user", "SPN Account Server Error", fmt.Sprintf(`The status of your SPN account could not be updated: %s`, err), - ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) resetSPNStatus(StatusFailed, true) log.Errorf("spn/captain: failed to update ineligible account: %s", err) return clientResultReconnect @@ -282,9 +276,7 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "SPN Not Included In Package", "Your current Portmaster Package does not include access to the SPN. Please upgrade your package on the Account Page.", spnOpenAccountPage, - ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) resetSPNStatus(StatusFailed, true) return clientResultReconnect } @@ -299,9 +291,7 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "Portmaster Package Issue", "Cannot enable SPN: "+message, spnOpenAccountPage, - ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) resetSPNStatus(StatusFailed, true) return clientResultReconnect } @@ -321,9 +311,7 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { "spn:tokens-exhausted", "SPN Access Tokens Exhausted", `The Portmaster failed to get new access tokens to access the SPN. The Portmaster will automatically retry to get new access tokens.`, - ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) resetSPNStatus(StatusFailed, false) } return clientResultRetry @@ -371,9 +359,7 @@ func clientConnectToHomeHub(ctx context.Context) clientComponentResult { Key: CfgOptionHomeHubPolicyKey, }, }, - ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) case errors.Is(err, ErrReInitSPNSuggested): notifications.NotifyError( @@ -389,18 +375,14 @@ func clientConnectToHomeHub(ctx context.Context) clientComponentResult { ResultAction: "display", }, }, - ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) default: notifications.NotifyWarn( "spn:home-hub-failure", "SPN Failed to Connect", fmt.Sprintf("Failed to connect to a home hub: %s. 
The Portmaster will retry to connect automatically.", err), - ) - // TODO(vladimir): this is not needed right - // .AttachToModule(module) + ).SyncWithState(module.states) } return clientResultReconnect From 5f3147edf88604175e21a1e9d316a96bf93824a1 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 18 Jul 2024 17:17:28 +0200 Subject: [PATCH 36/56] Improve starting, stopping, shutdown; resolve FIXMEs/TODOs --- base/api/endpoints_debug.go | 13 +-- base/api/endpoints_modules.go | 51 ------------ base/api/main.go | 4 - base/api/module.go | 1 + base/api/router.go | 15 +--- base/metrics/module.go | 18 +++-- cmds/portmaster-core/main.go | 91 ++++++++++----------- service/broadcasts/module.go | 7 ++ service/broadcasts/notify.go | 9 +-- service/compat/notify.go | 4 - service/config.go | 4 +- service/core/api.go | 5 +- service/core/core.go | 24 +----- service/instance.go | 114 ++++++++++++++++++++------- service/intel/filterlists/updater.go | 9 +-- service/mgr/{module.go => group.go} | 112 +++++++++++++++----------- service/mgr/group_ext.go | 92 +++++++++++++++++++++ service/mgr/manager.go | 35 +++++--- service/netquery/module_api.go | 7 +- service/network/clean.go | 1 - service/resolver/resolver-tcp.go | 20 ++--- service/status/module.go | 8 +- service/updates/module.go | 7 +- service/updates/restart.go | 4 +- spn/access/client.go | 8 -- spn/access/module.go | 32 ++++++-- spn/captain/api.go | 68 ++++++++-------- spn/captain/client.go | 33 ++++---- spn/captain/module.go | 17 ++-- spn/captain/navigation.go | 14 ++-- spn/navigator/update.go | 2 +- 31 files changed, 461 insertions(+), 368 deletions(-) delete mode 100644 base/api/endpoints_modules.go rename service/mgr/{module.go => group.go} (63%) create mode 100644 service/mgr/group_ext.go diff --git a/base/api/endpoints_debug.go b/base/api/endpoints_debug.go index da1397b6a..d2db39601 100644 --- a/base/api/endpoints_debug.go +++ b/base/api/endpoints_debug.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "errors" "fmt" "net/http" "os" @@ -128,20 +129,14 @@ You can easily view this data in your browser with this command (with Go install // ping responds with pong. func ping(ar *Request) (msg string, err error) { - // TODO: Remove upgrade to "ready" when all UI components have transitioned. - // if modules.IsStarting() || modules.IsShuttingDown() { - // return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) - // } - return "Pong.", nil } // ready checks if Portmaster has completed starting. func ready(ar *Request) (msg string, err error) { - // TODO(vladimir): provide alternative for this. Instance state? - // if modules.IsStarting() || modules.IsShuttingDown() { - // return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) - // } + if module.instance.Ready() { + return "", ErrorWithStatus(errors.New("portmaster is not ready, reload (F5) to try again"), http.StatusTooEarly) + } return "Portmaster is ready.", nil } diff --git a/base/api/endpoints_modules.go b/base/api/endpoints_modules.go deleted file mode 100644 index 7be727956..000000000 --- a/base/api/endpoints_modules.go +++ /dev/null @@ -1,51 +0,0 @@ -package api - -func registerModulesEndpoints() error { - // TODO(vladimir): do we need this? 
- // if err := RegisterEndpoint(Endpoint{ - // Path: "modules/status", - // Read: PermitUser, - // StructFunc: getStatusfunc, - // Name: "Get Module Status", - // Description: "Returns status information of all modules.", - // }); err != nil { - // return err - // } - - // TODO(vladimir): do we need this? - // if err := RegisterEndpoint(Endpoint{ - // Path: "modules/{moduleName:.+}/trigger/{eventName:.+}", - // Write: PermitSelf, - // ActionFunc: triggerEvent, - // Name: "Trigger Event", - // Description: "Triggers an event of an internal module.", - // }); err != nil { - // return err - // } - - return nil -} - -// func getStatusfunc(ar *Request) (i interface{}, err error) { -// status := modules.GetStatus() -// if status == nil { -// return nil, errors.New("modules not yet initialized") -// } -// return status, nil -// } - -// func triggerEvent(ar *Request) (msg string, err error) { -// // Get parameters. -// moduleName := ar.URLVars["moduleName"] -// eventName := ar.URLVars["eventName"] -// if moduleName == "" || eventName == "" { -// return "", errors.New("invalid parameters") -// } - -// // Inject event. -// if err := module.InjectEvent("api event injection", moduleName, eventName, nil); err != nil { -// return "", fmt.Errorf("failed to inject event: %w", err) -// } - -// return "event successfully injected", nil -// } diff --git a/base/api/main.go b/base/api/main.go index 83c97c0f4..06875a399 100644 --- a/base/api/main.go +++ b/base/api/main.go @@ -45,10 +45,6 @@ func prep() error { return err } - if err := registerModulesEndpoints(); err != nil { - return err - } - return registerMetaEndpoints() } diff --git a/base/api/module.go b/base/api/module.go index ba4a038f0..65f01b629 100644 --- a/base/api/module.go +++ b/base/api/module.go @@ -60,4 +60,5 @@ func New(instance instance) (*API, error) { type instance interface { Config() *config.Config SetCmdLineOperation(f func() error) + Ready() bool } diff --git a/base/api/router.go b/base/api/router.go index a5fdff957..53c37495f 100644 --- a/base/api/router.go +++ b/base/api/router.go @@ -272,15 +272,6 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error { return nil } - // Wait for the owning module to be ready. - // TODO(vladimir): no need to check for status anymore right? - // if moduleHandler, ok := handler.(ModuleHandler); ok { - // if !moduleIsReady(moduleHandler.BelongsTo()) { - // http.Error(lrw, "The API endpoint is not ready yet. Reload (F5) to try again.", http.StatusServiceUnavailable) - // return nil - // } - // } - // Check if we have a handler. if handler == nil { http.Error(lrw, "Not found.", http.StatusNotFound) @@ -290,10 +281,8 @@ func (mh *mainHandler) handle(w http.ResponseWriter, r *http.Request) error { // Format panics in handler. defer func() { if panicValue := recover(); panicValue != nil { - // Report failure via module system. - // TODO(vladimir): do we need panic report here - // me := module.NewPanicError("api request", "custom", panicValue) - // me.Report() + // Log failure. + log.Errorf("api: handler panic: %s", panicValue) // Respond with a server error. 
if devMode() { http.Error( diff --git a/base/metrics/module.go b/base/metrics/module.go index c9315f9d5..a1e5bd378 100644 --- a/base/metrics/module.go +++ b/base/metrics/module.go @@ -42,6 +42,7 @@ var ( registry []Metric registryLock sync.RWMutex + readyToRegister bool firstMetricRegistered bool metricNamespace string globalLabels = make(map[string]string) @@ -69,6 +70,13 @@ func start() error { } } + // Mark registry as ready to register metrics. + func() { + registryLock.Lock() + defer registryLock.Unlock() + readyToRegister = true + }() + if err := registerInfoMetric(); err != nil { return err } @@ -128,14 +136,14 @@ func register(m Metric) error { registry = append(registry, m) sort.Sort(byLabeledID(registry)) + // Check if we can already register. + if !readyToRegister { + return fmt.Errorf("registering metric %q too early", m.ID()) + } + // Set flag that first metric is now registered. firstMetricRegistered = true - // TODO(vladimir): With the new modules system there is no way this can fail. I may be wrong. - // if module.Status() < modules.StatusStarting { - // return fmt.Errorf("registering metric %q too early", m.ID()) - // } - return nil } diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index ab3d8a0f6..407f67985 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -57,33 +57,30 @@ func main() { return } - // Create - instance, err := service.New("2.0.0", &service.ServiceConfig{ - ShutdownFunc: func(exitCode int) { - fmt.Printf("ExitCode: %d\n", exitCode) - }, - }) + // Create instance. + instance, err := service.New(&service.ServiceConfig{}) if err != nil { fmt.Printf("error creating an instance: %s\n", err) - return + os.Exit(2) } - // execute command if available + // Execute command line operation, if available. if instance.CommandLineOperation != nil { // Run the function and exit. + err = instance.CommandLineOperation() if err != nil { fmt.Fprintf(os.Stderr, "cmdline operation failed: %s\n", err) - os.Exit(1) + os.Exit(3) } os.Exit(0) } // Start go func() { - err = instance.Group.Start() + err = instance.Start() if err != nil { fmt.Printf("instance start failed: %s\n", err) - return + os.Exit(1) } }() @@ -99,50 +96,48 @@ func main() { sigUSR1, ) -signalLoop: - for { - select { - case sig := <-signalCh: - // Only print and continue to wait if SIGUSR1 - if sig == sigUSR1 { - printStackTo(os.Stderr, "PRINTING STACK ON REQUEST") - continue signalLoop - } - + select { + case sig := <-signalCh: + // Only print and continue to wait if SIGUSR1 + if sig == sigUSR1 { + printStackTo(os.Stderr, "PRINTING STACK ON REQUEST") + } else { fmt.Println(" ") // CLI output. slog.Warn("program was interrupted, stopping") + } - // catch signals during shutdown - go func() { - forceCnt := 5 - for { - <-signalCh - forceCnt-- - if forceCnt > 0 { - fmt.Printf(" again, but already shutting down - %d more to force\n", forceCnt) - } else { - printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") - os.Exit(1) - } - } - }() - - go func() { - time.Sleep(3 * time.Minute) - printStackTo(os.Stderr, "PRINTING STACK - TAKING TOO LONG FOR SHUTDOWN") - os.Exit(1) - }() + case <-instance.Stopped(): + os.Exit(instance.ExitCode()) + } - if err := instance.Stop(); err != nil { - slog.Error("failed to stop portmaster", "err", err) - continue signalLoop + // Catch signals during shutdown. + // Rapid unplanned disassembly after 5 interrupts. 
+ go func() { + forceCnt := 5 + for { + <-signalCh + forceCnt-- + if forceCnt > 0 { + fmt.Printf(" again, but already shutting down - %d more to force\n", forceCnt) + } else { + printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") + os.Exit(1) } - break signalLoop - - case <-instance.Done(): - break signalLoop } + }() + + // Rapid unplanned disassembly after 3 minutes. + go func() { + time.Sleep(3 * time.Minute) + printStackTo(os.Stderr, "PRINTING STACK - TAKING TOO LONG FOR SHUTDOWN") + os.Exit(1) + }() + + // Stop instance. + if err := instance.Stop(); err != nil { + slog.Error("failed to stop portmaster", "err", err) } + os.Exit(instance.ExitCode()) } func printStackTo(writer io.Writer, msg string) { diff --git a/service/broadcasts/module.go b/service/broadcasts/module.go index 31ab162f4..a3968933e 100644 --- a/service/broadcasts/module.go +++ b/service/broadcasts/module.go @@ -13,6 +13,8 @@ import ( type Broadcasts struct { mgr *mgr.Manager instance instance + + states *mgr.StateMgr } func (b *Broadcasts) Manager() *mgr.Manager { @@ -27,6 +29,10 @@ func (b *Broadcasts) Stop() error { return nil } +func (b *Broadcasts) States() *mgr.StateMgr { + return b.states +} + var ( db = database.NewInterface(&database.Options{ Local: true, @@ -75,6 +81,7 @@ func New(instance instance) (*Broadcasts, error) { module = &Broadcasts{ mgr: m, instance: instance, + states: m.NewStateMgr(), } if err := prep(); err != nil { diff --git a/service/broadcasts/notify.go b/service/broadcasts/notify.go index 2ab6c9645..a010f2494 100644 --- a/service/broadcasts/notify.go +++ b/service/broadcasts/notify.go @@ -210,12 +210,9 @@ func handleBroadcast(bn *BroadcastNotification, matchingDataAccessor accessor.Ac // Display notification. n.Save() - - // Attach to module to raise more awareness. - // TODO(vladimir): is there a need for this? - // if bn.AttachToModule { - // n.AttachToModule(module) - // } + if bn.AttachToModule { + n.SyncWithState(module.states) + } return nil } diff --git a/service/compat/notify.go b/service/compat/notify.go index 72932aeb3..3cb64bf5b 100644 --- a/service/compat/notify.go +++ b/service/compat/notify.go @@ -138,10 +138,6 @@ func (issue *systemIssue) notify(err error) { systemIssueNotification = n n.SyncWithState(module.states) - - // Report the raw error as module error. - // FIXME(vladimir): Is there a need for this kind of error reporting? - // module.NewErrorMessage("selfcheck", err).Report() } func resetSystemIssue() { diff --git a/service/config.go b/service/config.go index 5c6884348..85a98603f 100644 --- a/service/config.go +++ b/service/config.go @@ -1,5 +1,3 @@ package service -type ServiceConfig struct { - ShutdownFunc func(exitCode int) -} +type ServiceConfig struct{} diff --git a/service/core/api.go b/service/core/api.go index 2c25262f0..ea03ac420 100644 --- a/service/core/api.go +++ b/service/core/api.go @@ -113,10 +113,7 @@ func registerAPIEndpoints() error { func shutdown(_ *api.Request) (msg string, err error) { log.Warning("core: user requested shutdown via action") - // Do not run in worker, as this would block itself here. 
- // TODO(vladimir): replace with something better - go ShutdownHook() //nolint:errcheck - + module.instance.Shutdown(0) return "shutdown initiated", nil } diff --git a/service/core/core.go b/service/core/core.go index d689b5ff2..983e9eee2 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -5,7 +5,6 @@ import ( "flag" "fmt" "sync/atomic" - "time" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/metrics" @@ -16,7 +15,6 @@ import ( _ "github.com/safing/portmaster/service/status" _ "github.com/safing/portmaster/service/sync" _ "github.com/safing/portmaster/service/ui" - "github.com/safing/portmaster/service/updates" ) const ( @@ -63,8 +61,6 @@ func init() { false, "disable shutdown event to keep app and notifier open when core shuts down", ) - - // modules.SetGlobalShutdownFn(shutdownHook) } func prep() error { @@ -94,22 +90,6 @@ func start() error { return nil } -func ShutdownHook() { - // Notify everyone of the restart/shutdown. - if !updates.IsRestarting() { - // Only trigger shutdown event if not disabled. - if !disableShutdownEvent { - module.EventShutdown.Submit(struct{}{}) - } - } else { - module.EventRestart.Submit(struct{}{}) - } - - // Wait a bit for the event to propagate. - // TODO(vladimir): is this necessary? - time.Sleep(100 * time.Millisecond) -} - var ( module *Core shimLoaded atomic.Bool @@ -137,4 +117,6 @@ func New(instance instance) (*Core, error) { return module, nil } -type instance interface{} +type instance interface { + Shutdown(exitCode int) +} diff --git a/service/instance.go b/service/instance.go index da266b041..d53cbaec2 100644 --- a/service/instance.go +++ b/service/instance.go @@ -1,7 +1,10 @@ package service import ( + "context" "fmt" + "sync/atomic" + "time" "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" @@ -45,9 +48,12 @@ import ( // Instance is an instance of a portmaste service. type Instance struct { - version string + ctx context.Context + cancelCtx context.CancelFunc + serviceGroup *mgr.Group + + exitCode atomic.Int32 - *mgr.Group database *dbmodule.DBModule config *config.Config api *api.API @@ -80,7 +86,7 @@ type Instance struct { access *access.Access // SPN modules - SpnGroup *mgr.Group + SpnGroup *mgr.ExtendedGroup cabin *cabin.Cabin navigator *navigator.Navigator captain *captain.Captain @@ -95,11 +101,10 @@ type Instance struct { } // New returns a new portmaster service instance. -func New(version string, svcCfg *ServiceConfig) (*Instance, error) { +func New(svcCfg *ServiceConfig) (*Instance, error) { // Create instance to pass it to modules. 
- instance := &Instance{ - version: version, - } + instance := &Instance{} + instance.ctx, instance.cancelCtx = context.WithCancel(context.Background()) var err error @@ -142,7 +147,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create core module: %w", err) } - instance.updates, err = updates.New(instance, svcCfg.ShutdownFunc) + instance.updates, err = updates.New(instance) if err != nil { return nil, fmt.Errorf("create updates module: %w", err) } @@ -228,7 +233,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { if err != nil { return nil, fmt.Errorf("create navigator module: %w", err) } - instance.captain, err = captain.New(instance, svcCfg.ShutdownFunc) + instance.captain, err = captain.New(instance) if err != nil { return nil, fmt.Errorf("create captain module: %w", err) } @@ -258,7 +263,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { } // Add all modules to instance group. - instance.Group = mgr.NewGroup( + instance.serviceGroup = mgr.NewGroup( instance.database, instance.config, instance.api, @@ -275,8 +280,8 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.ui, instance.profile, - instance.network, instance.netquery, + instance.network, instance.firewall, instance.filterLists, instance.interception, @@ -292,7 +297,7 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { ) // SPN Group - instance.SpnGroup = mgr.NewGroup( + instance.SpnGroup = mgr.NewExtendedGroup( instance.cabin, instance.navigator, instance.captain, @@ -304,9 +309,6 @@ func New(version string, svcCfg *ServiceConfig) (*Instance, error) { instance.terminal, ) - // FIXME: call this before to trigger shutdown/restart event - // core.ShutdownHook() - return instance, nil } @@ -316,11 +318,6 @@ func (i *Instance) SetSleep(enabled bool) { i.captain.SetSleep(enabled) } -// Version returns the version. -func (i *Instance) Version() string { - return i.version -} - // Database returns the database module. func (i *Instance) Database() *dbmodule.DBModule { return i.database @@ -507,7 +504,7 @@ func (i *Instance) Core() *core.Core { } // SPNGroup returns the group of all SPN modules. -func (i *Instance) SPNGroup() *mgr.Group { +func (i *Instance) SPNGroup() *mgr.ExtendedGroup { return i.SpnGroup } @@ -525,10 +522,10 @@ func (i *Instance) SetCmdLineOperation(f func() error) { i.CommandLineOperation = f } -// GetStatus returns the current Status of all group modules. -func (i *Instance) GetStatus() []mgr.StateUpdate { - mainStates := i.Group.GetStatus() - spnStates := i.SpnGroup.GetStatus() +// GetStates returns the current states of all group modules. +func (i *Instance) GetStates() []mgr.StateUpdate { + mainStates := i.serviceGroup.GetStates() + spnStates := i.SpnGroup.GetStates() updates := make([]mgr.StateUpdate, 0, len(mainStates)+len(spnStates)) updates = append(updates, mainStates...) @@ -537,9 +534,68 @@ func (i *Instance) GetStatus() []mgr.StateUpdate { return updates } -// AddStatusCallback adds the given callback function to all group modules that +// AddStatesCallback adds the given callback function to all group modules that // expose a state manager at States(). 
-func (i *Instance) AddStatusCallback(callbackName string, callback mgr.EventCallbackFunc[mgr.StateUpdate]) { - i.Group.AddStatusCallback(callbackName, callback) - i.SpnGroup.AddStatusCallback(callbackName, callback) +func (i *Instance) AddStatesCallback(callbackName string, callback mgr.EventCallbackFunc[mgr.StateUpdate]) { + i.serviceGroup.AddStatesCallback(callbackName, callback) + i.SpnGroup.AddStatesCallback(callbackName, callback) +} + +// Ready returns whether all modules in the main service module group have been started and are still running. +func (i *Instance) Ready() bool { + return i.serviceGroup.Ready() +} + +// Ctx returns the instance context. +// It is only canceled on shutdown. +func (i *Instance) Ctx() context.Context { + return i.ctx +} + +// Start starts the instance. +func (i *Instance) Start() error { + return i.serviceGroup.Start() +} + +// Stop stops the instance and cancels the instance context when done. +func (i *Instance) Stop() error { + defer i.cancelCtx() + return i.serviceGroup.Stop() +} + +// Shutdown asynchronously stops the instance. +func (i *Instance) Shutdown(exitCode int) { + i.exitCode.Store(int32(exitCode)) + + m := mgr.New("instance") + m.Go("shutdown", func(w *mgr.WorkerCtx) error { + for { + if err := i.Stop(); err != nil { + w.Error("failed to shutdown", "error", err, "retry", "1s") + time.Sleep(1 * time.Second) + } else { + return nil + } + } + }) +} + +// Stopping returns whether the instance is shutting down. +func (i *Instance) Stopping() bool { + return i.ctx.Err() == nil +} + +// Stopped returns a channel that is triggered when the instance has shut down. +func (i *Instance) Stopped() <-chan struct{} { + return i.ctx.Done() +} + +// SetExitCode sets the exit code on the instance. +func (i *Instance) SetExitCode(exitCode int) { + i.exitCode.Store(int32(exitCode)) +} + +// ExitCode returns the set exit code of the instance. +func (i *Instance) ExitCode() int { + return int(i.exitCode.Load()) } diff --git a/service/intel/filterlists/updater.go b/service/intel/filterlists/updater.go index c0be14dac..72f7b82e7 100644 --- a/service/intel/filterlists/updater.go +++ b/service/intel/filterlists/updater.go @@ -24,11 +24,10 @@ var updateInProgress = abool.New() func tryListUpdate(ctx context.Context) error { err := performUpdate(ctx) if err != nil { - // Check if we are shutting down. - // TODO(vladimir): Do we need stopping detection? - // if module.IsStopping() { - // return nil - // } + // Check if we are shutting down, as to not raise a false alarm. + if module.mgr.IsDone() { + return nil + } // Check if the module already has a failure status set. If not, set a // generic one with the returned error. diff --git a/service/mgr/module.go b/service/mgr/group.go similarity index 63% rename from service/mgr/module.go rename to service/mgr/group.go index 45d917949..51d1fab17 100644 --- a/service/mgr/module.go +++ b/service/mgr/group.go @@ -6,11 +6,18 @@ import ( "fmt" "reflect" "strings" - "sync" "sync/atomic" "time" ) +var ( + // ErrUnsuitableGroupState is returned when an operation cannot be executed due to an unsuitable state. + ErrUnsuitableGroupState = errors.New("unsuitable group state") + + // ErrInvalidGroupState is returned when a group is in an invalid state and cannot be recovered. 
+ ErrInvalidGroupState = errors.New("invalid group state") +) + const ( groupStateOff int32 = iota groupStateStarting @@ -40,10 +47,6 @@ func groupStateToString(state int32) string { type Group struct { modules []*groupModule - ctx context.Context - cancelCtx context.CancelFunc - ctxLock sync.Mutex - state atomic.Int32 } @@ -65,11 +68,12 @@ func NewGroup(modules ...Module) *Group { g := &Group{ modules: make([]*groupModule, 0, len(modules)), } - g.initGroupContext() // Initialize groups modules. for _, m := range modules { - // Skip non-values. + mgr := m.Manager() + + // Check module. switch { case m == nil: // Skip nil values to allow for cleaner code. @@ -78,12 +82,19 @@ func NewGroup(modules ...Module) *Group { // If nil values are given via a struct, they are will be interfaces to a // nil type. Ignore these too. continue + case mgr == nil: + // Ignore modules that do not return a manager. + continue + case mgr.Name() == "": + // Force name if none is set. + // TODO: Unsafe if module is already logging, etc. + mgr.setName(makeModuleName(m)) } // Add module to group. g.modules = append(g.modules, &groupModule{ module: m, - mgr: newManager(g.ctx, makeModuleName(m), "module"), + mgr: mgr, }) } @@ -94,12 +105,21 @@ func NewGroup(modules ...Module) *Group { // If a module fails to start, itself and all previous modules // will be stopped in the reverse order. func (g *Group) Start() error { - if !g.state.CompareAndSwap(groupStateOff, groupStateStarting) { - return fmt.Errorf("group is not off, state: %s", groupStateToString(g.state.Load())) + // Check group state. + switch g.state.Load() { + case groupStateRunning: + // Already running. + return nil + case groupStateInvalid: + // Something went terribly wrong, cannot recover from here. + return fmt.Errorf("%w: cannot recover", ErrInvalidGroupState) + default: + if !g.state.CompareAndSwap(groupStateOff, groupStateStarting) { + return fmt.Errorf("%w: group is not off, state: %s", ErrUnsuitableGroupState, groupStateToString(g.state.Load())) + } } - g.initGroupContext() - + // Start modules. for i, m := range g.modules { m.mgr.Info("starting") startTime := time.Now() @@ -118,16 +138,28 @@ func (g *Group) Start() error { duration := time.Since(startTime) m.mgr.Info("started", "time", duration.String()) } + g.state.Store(groupStateRunning) return nil } // Stop stops all modules in the group in the reverse order. func (g *Group) Stop() error { - if !g.state.CompareAndSwap(groupStateRunning, groupStateStopping) { - return fmt.Errorf("group is not running, state: %s", groupStateToString(g.state.Load())) + // Check group state. + switch g.state.Load() { + case groupStateOff: + // Already stopped. + return nil + case groupStateInvalid: + // Something went terribly wrong, cannot recover from here. + return fmt.Errorf("%w: cannot recover", ErrInvalidGroupState) + default: + if !g.state.CompareAndSwap(groupStateRunning, groupStateStopping) { + return fmt.Errorf("%w: group is not running, state: %s", ErrUnsuitableGroupState, groupStateToString(g.state.Load())) + } } + // Stop modules. if !g.stopFrom(len(g.modules) - 1) { g.state.Store(groupStateInvalid) return errors.New("failed to stop") @@ -139,6 +171,8 @@ func (g *Group) Stop() error { func (g *Group) stopFrom(index int) (ok bool) { ok = true + + // Stop modules. 
for i := index; i >= 0; i-- { m := g.modules[i] @@ -162,42 +196,26 @@ func (g *Group) stopFrom(index int) (ok bool) { } } - g.stopGroupContext() - return -} - -func (g *Group) initGroupContext() { - g.ctxLock.Lock() - defer g.ctxLock.Unlock() - - g.ctx, g.cancelCtx = context.WithCancel(context.Background()) -} - -func (g *Group) stopGroupContext() { - g.ctxLock.Lock() - defer g.ctxLock.Unlock() - - g.cancelCtx() -} - -// Done returns the context Done channel. -func (g *Group) Done() <-chan struct{} { - g.ctxLock.Lock() - defer g.ctxLock.Unlock() + // Reset modules. + if !ok { + // Stopping failed somewhere, reset anyway after a short wait. + // This will be very uncommon and can help to mitigate race conditions in these events. + time.Sleep(time.Second) + } + for _, m := range g.modules { + m.mgr.Reset() + } - return g.ctx.Done() + return ok } -// IsDone checks whether the manager context is done. -func (g *Group) IsDone() bool { - g.ctxLock.Lock() - defer g.ctxLock.Unlock() - - return g.ctx.Err() != nil +// Ready returns whether all modules in the group have been started and are still running. +func (g *Group) Ready() bool { + return g.state.Load() == groupStateRunning } -// GetStatus returns the current Status of all group modules. -func (g *Group) GetStatus() []StateUpdate { +// GetStates returns the current states of all group modules. +func (g *Group) GetStates() []StateUpdate { updates := make([]StateUpdate, 0, len(g.modules)) for _, gm := range g.modules { if stateful, ok := gm.module.(StatefulModule); ok { @@ -207,9 +225,9 @@ func (g *Group) GetStatus() []StateUpdate { return updates } -// AddStatusCallback adds the given callback function to all group modules that +// AddStatesCallback adds the given callback function to all group modules that // expose a state manager at States(). -func (g *Group) AddStatusCallback(callbackName string, callback EventCallbackFunc[StateUpdate]) { +func (g *Group) AddStatesCallback(callbackName string, callback EventCallbackFunc[StateUpdate]) { for _, gm := range g.modules { if stateful, ok := gm.module.(StatefulModule); ok { stateful.States().AddCallback(callbackName, callback) diff --git a/service/mgr/group_ext.go b/service/mgr/group_ext.go new file mode 100644 index 000000000..dcd17236e --- /dev/null +++ b/service/mgr/group_ext.go @@ -0,0 +1,92 @@ +package mgr + +import ( + "context" + "errors" + "sync" + "time" +) + +// ExtendedGroup extends the group with additional helpful functionality. +type ExtendedGroup struct { + *Group + + ensureCtx context.Context + ensureCancel context.CancelFunc + ensureLock sync.Mutex +} + +// NewExtendedGroup returns a new extended group. +func NewExtendedGroup(modules ...Module) *ExtendedGroup { + return UpgradeGroup(NewGroup(modules...)) +} + +// UpgradeGroup upgrades a regular group to an extended group. +func UpgradeGroup(g *Group) *ExtendedGroup { + return &ExtendedGroup{ + Group: g, + ensureCancel: func() {}, + } +} + +// EnsureStartedWorker tries to start the group until it succeeds or fails permanently. +func (eg *ExtendedGroup) EnsureStartedWorker(wCtx *WorkerCtx) error { + // Setup worker. 
+ var ctx context.Context + func() { + eg.ensureLock.Lock() + defer eg.ensureLock.Unlock() + eg.ensureCancel() + eg.ensureCtx, eg.ensureCancel = context.WithCancel(wCtx.Ctx()) + ctx = eg.ensureCtx + }() + + for { + err := eg.Group.Start() + switch { + case err == nil: + return nil + case errors.Is(err, ErrInvalidGroupState): + wCtx.Debug("group start delayed", "error", err) + default: + return err + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(1 * time.Second): + } + } +} + +// EnsureStoppedWorker tries to stop the group until it succeeds or fails permanently. +func (eg *ExtendedGroup) EnsureStoppedWorker(wCtx *WorkerCtx) error { + // Setup worker. + var ctx context.Context + func() { + eg.ensureLock.Lock() + defer eg.ensureLock.Unlock() + eg.ensureCancel() + eg.ensureCtx, eg.ensureCancel = context.WithCancel(wCtx.Ctx()) + ctx = eg.ensureCtx + }() + + for { + err := eg.Group.Stop() + switch { + case err == nil: + return nil + case errors.Is(err, ErrInvalidGroupState): + wCtx.Debug("group stop delayed", "error", err) + default: + return err + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(1 * time.Second): + } + } +} diff --git a/service/mgr/manager.go b/service/mgr/manager.go index cd346c7fa..4be527854 100644 --- a/service/mgr/manager.go +++ b/service/mgr/manager.go @@ -7,6 +7,9 @@ import ( "time" ) +// ManagerNameSLogKey is used as the logging key for the name of the manager. +var ManagerNameSLogKey = "manager" + // Manager manages workers. type Manager struct { name string @@ -21,21 +24,16 @@ type Manager struct { // New returns a new manager. func New(name string) *Manager { - return NewWithContext(context.Background(), name) -} - -// NewWithContext returns a new manager that uses the given context. -func NewWithContext(ctx context.Context, name string) *Manager { - return newManager(ctx, name, "manager") + return newManager(name) } -func newManager(ctx context.Context, name string, logNameKey string) *Manager { +func newManager(name string) *Manager { m := &Manager{ name: name, - logger: slog.Default().With(logNameKey, name), + logger: slog.Default().With(ManagerNameSLogKey, name), workersDone: make(chan struct{}), } - m.ctx, m.cancelCtx = context.WithCancel(ctx) + m.ctx, m.cancelCtx = context.WithCancel(context.Background()) return m } @@ -44,6 +42,13 @@ func (m *Manager) Name() string { return m.name } +// setName sets the manager name and resets the logger to use that name. +// Not safe for concurrent use with any other module methods. +func (m *Manager) setName(newName string) { + m.name = newName + m.logger = slog.Default().With(ManagerNameSLogKey, m.name) +} + // Ctx returns the worker context. func (m *Manager) Ctx() context.Context { return m.ctx @@ -162,3 +167,15 @@ func (m *Manager) workerDone() { } } } + +// Reset resets the manager in order to be able to be used again. +// In the process, the current context is canceled. +// As part of a module (in a group), the module might be stopped and started again. +// This method is not goroutine-safe. The caller must make sure the manager is +// not being used in any way during execution. 
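+//
+// A rough usage sketch (doWork is a placeholder worker function, not part of
+// this change):
+//
+//	m := New("example")
+//	m.Go("worker", doWork)
+//	m.Cancel()
+//	m.WaitForWorkers(time.Second)
+//	m.Reset() // the manager can now be used again with a fresh context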
+func (m *Manager) Reset() {
+	m.cancelCtx()
+	m.ctx, m.cancelCtx = context.WithCancel(context.Background())
+	m.workerCnt.Store(0)
+	m.workersDone = make(chan struct{})
+}
diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go
index 7d7047613..e5250144a 100644
--- a/service/netquery/module_api.go
+++ b/service/netquery/module_api.go
@@ -284,7 +284,12 @@ func (nq *NetQuery) Start() error {
 }
 
 func (nq *NetQuery) Stop() error {
-	// we don't use m.Module.Ctx here because it is already cancelled when stop is called.
+	// Cancel the module context.
+	nq.mgr.Cancel()
+	// Wait for all workers before we start the shutdown.
+	nq.mgr.WaitForWorkers(time.Minute)
+
+	// we don't use the module ctx here because it is already canceled.
 	// just give the clean up 1 minute to happen and abort otherwise.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
 	defer cancel()
diff --git a/service/network/clean.go b/service/network/clean.go
index 61767dc2c..331161a17 100644
--- a/service/network/clean.go
+++ b/service/network/clean.go
@@ -53,7 +53,6 @@ func connectionCleaner(ctx *mgr.WorkerCtx) error {
 func cleanConnections() (activePIDs map[int]struct{}) {
 	activePIDs = make(map[int]struct{})
 
-	// FIXME(vladimir): This was previously a MicroTask but it does not seem right, to run it asynchronously. Is'nt activePIDs going to be used after the function is called?
 	_ = module.mgr.Do("clean connections", func(ctx *mgr.WorkerCtx) error {
 		now := time.Now().UTC()
 		nowUnix := now.Unix()
diff --git a/service/resolver/resolver-tcp.go b/service/resolver/resolver-tcp.go
index 0e14f3a71..261f0e5bc 100644
--- a/service/resolver/resolver-tcp.go
+++ b/service/resolver/resolver-tcp.go
@@ -120,18 +120,16 @@ func (tr *TCPResolver) getOrCreateResolverConn(ctx context.Context) (*tcpResolve
 			log.Warningf("resolver: heartbeat for dns client %s failed", tr.resolver.Info.DescriptiveName())
 		case <-ctx.Done():
 			return nil, ctx.Err()
-		// TODO(vladimir): there is no need for this right?
-		// case <-module.Stopping():
-		// 	return nil, ErrShuttingDown
+		case <-module.mgr.Done():
+			return nil, ErrShuttingDown
 		}
 	} else {
 		// If there is no resolver, check if we are shutting down before dialing!
 		select {
 		case <-ctx.Done():
 			return nil, ctx.Err()
-		// TODO(vladimir): there is no need for this right?
-		// case <-module.Stopping():
-		// 	return nil, ErrShuttingDown
+		case <-module.mgr.Done():
+			return nil, ErrShuttingDown
 		default:
 		}
 	}
@@ -207,9 +205,8 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
 	case resolverConn.queries <- tq:
 	case <-ctx.Done():
 		return nil, ctx.Err()
-	// TODO(vladimir): there is no need for this right?
-	// case <-module.Stopping():
-	// 	return nil, ErrShuttingDown
+	case <-module.mgr.Done():
+		return nil, ErrShuttingDown
 	case <-time.After(defaultRequestTimeout):
 		return nil, ErrTimeout
 	}
@@ -220,9 +217,8 @@ func (tr *TCPResolver) Query(ctx context.Context, q *Query) (*RRCache, error) {
 	case reply = <-tq.Response:
 	case <-ctx.Done():
 		return nil, ctx.Err()
-	// TODO(vladimir): there is no need for this right?
- // case <-module.Stopping(): - // return nil, ErrShuttingDown + case <-module.mgr.Done(): + return nil, ErrShuttingDown case <-time.After(defaultRequestTimeout): return nil, ErrTimeout } diff --git a/service/status/module.go b/service/status/module.go index 1846c4e8f..77ec6716b 100644 --- a/service/status/module.go +++ b/service/status/module.go @@ -52,9 +52,9 @@ func (s *Status) Start() error { s.statesLock.Lock() defer s.statesLock.Unlock() // Add status callback within the lock so we can force the right order. - s.instance.AddStatusCallback("status update", s.handleModuleStatusUpdate) + s.instance.AddStatesCallback("status update", s.handleModuleStatusUpdate) // Get initial states. - for _, stateUpdate := range s.instance.GetStatus() { + for _, stateUpdate := range s.instance.GetStates() { s.states[stateUpdate.Module] = stateUpdate s.deriveNotificationsFromStateUpdate(stateUpdate) } @@ -101,6 +101,6 @@ func New(instance instance) (*Status, error) { type instance interface { NetEnv() *netenv.NetEnv - GetStatus() []mgr.StateUpdate - AddStatusCallback(callbackName string, callback mgr.EventCallbackFunc[mgr.StateUpdate]) + GetStates() []mgr.StateUpdate + AddStatesCallback(callbackName string, callback mgr.EventCallbackFunc[mgr.StateUpdate]) } diff --git a/service/updates/module.go b/service/updates/module.go index 43c2a48cc..7d5da2b72 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -20,8 +20,7 @@ type Updates struct { EventResourcesUpdated *mgr.EventMgr[struct{}] EventVersionsUpdated *mgr.EventMgr[struct{}] - instance instance - shutdownFunc func(exitCode int) + instance instance } var ( @@ -30,7 +29,7 @@ var ( ) // New returns a new UI module. -func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { +func New(instance instance) (*Updates, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } @@ -45,7 +44,6 @@ func New(instance instance, shutdownFunc func(exitCode int)) (*Updates, error) { EventResourcesUpdated: mgr.NewEventMgr[struct{}](ResourceUpdateEvent, m), EventVersionsUpdated: mgr.NewEventMgr[struct{}](VersionUpdateEvent, m), instance: instance, - shutdownFunc: shutdownFunc, } if err := prep(); err != nil { return nil, err @@ -77,4 +75,5 @@ func (u *Updates) Stop() error { type instance interface { API() *api.API Config() *config.Config + Shutdown(exitCode int) } diff --git a/service/updates/restart.go b/service/updates/restart.go index bed86ea84..b2aa4e33b 100644 --- a/service/updates/restart.go +++ b/service/updates/restart.go @@ -114,9 +114,9 @@ func automaticRestart(w *mgr.WorkerCtx) error { // Set restart exit code. if !rebooting { - module.shutdownFunc(RestartExitCode) + module.instance.Shutdown(RestartExitCode) } else { - module.shutdownFunc(0) + module.instance.Shutdown(0) } } diff --git a/spn/access/client.go b/spn/access/client.go index 70381c855..fddae23df 100644 --- a/spn/access/client.go +++ b/spn/access/client.go @@ -57,16 +57,8 @@ func makeClientRequest(opts *clientRequestOptions) (resp *http.Response, err err // Get context for request. var ctx context.Context var cancel context.CancelFunc - // TODO(vladimir): can the module not be online? - // if module.Online() { - // Only use module context if online. ctx, cancel = context.WithTimeout(module.mgr.Ctx(), opts.requestTimeout) defer cancel() - // } else { - // // Otherwise, use the background context. 
-	// ctx, cancel = context.WithTimeout(context.Background(), opts.requestTimeout)
-	// defer cancel()
-	// }
 
 	// Create new request.
 	request, err := http.NewRequestWithContext(ctx, opts.method, opts.url, nil)
diff --git a/spn/access/module.go b/spn/access/module.go
index 92857bbc0..c407518a2 100644
--- a/spn/access/module.go
+++ b/spn/access/module.go
@@ -73,15 +73,28 @@ func prep() error {
 }
 
 func start() error {
+	// Add config listener to enable/disable SPN.
 	module.instance.Config().EventConfigChange.AddCallback("spn enable check", func(wc *mgr.WorkerCtx, s struct{}) (bool, error) {
+		// Do not do anything when we are shutting down.
+		if module.instance.Stopping() {
+			return true, nil
+		}
+
 		enabled := config.GetAsBool("spn/enable", false)
 		if enabled() {
-			return false, module.instance.SPNGroup().Start()
+			module.mgr.Go("ensure SPN is started", module.instance.SPNGroup().EnsureStartedWorker)
 		} else {
-			return false, module.instance.SPNGroup().Stop()
+			module.mgr.Go("ensure SPN is stopped", module.instance.SPNGroup().EnsureStoppedWorker)
 		}
+		return false, nil
 	})
 
+	// Check if we need to enable SPN now.
+	enabled := config.GetAsBool("spn/enable", false)
+	if enabled() {
+		module.mgr.Go("ensure SPN is started", module.instance.SPNGroup().EnsureStartedWorker)
+	}
+
 	// Initialize zones.
 	if err := InitializeZones(); err != nil {
 		return err
@@ -99,6 +112,12 @@ func start() error {
 }
 
 func stop() error {
+	// Make sure SPN is stopped before we proceed.
+	err := module.mgr.Do("ensure SPN is shut down", module.instance.SPNGroup().EnsureStoppedWorker)
+	if err != nil {
+		log.Errorf("access: stop SPN: %s", err)
+	}
+
 	if conf.Client() {
 		// Store tokens to database.
 		storeTokens()
@@ -112,7 +131,7 @@ func stop() error {
 
 // UpdateAccount updates the user account and fetches new tokens, if needed.
 func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error {
-	// Schedule next call this will change if other conditions are met bellow.
+	// Schedule next call - this will change if other conditions are met below.
 	module.updateAccountWorkerMgr.Delay(24 * time.Hour)
 
 	// Retry sooner if the token issuer is failing.
@@ -186,10 +205,6 @@ func tokenIssuerFailed() {
 	if !tokenIssuerIsFailing.SetToIf(false, true) {
 		return
 	}
-	// TODO(vladimir): Do we need this check?
-	// if !module.Online() {
-	// 	return
-	// }
 
 	module.updateAccountWorkerMgr.Delay(tokenIssuerRetryDuration)
 }
@@ -239,5 +254,6 @@ func New(instance instance) (*Access, error) {
 
 type instance interface {
 	Config() *config.Config
-	SPNGroup() *mgr.Group
+	SPNGroup() *mgr.ExtendedGroup
+	Stopping() bool
 }
diff --git a/spn/captain/api.go b/spn/captain/api.go
index fc19136f5..ec4987670 100644
--- a/spn/captain/api.go
+++ b/spn/captain/api.go
@@ -1,7 +1,12 @@
 package captain
 
 import (
+	"fmt"
+
 	"github.com/safing/portmaster/base/api"
+	"github.com/safing/portmaster/base/config"
+	"github.com/safing/portmaster/base/database"
+	"github.com/safing/portmaster/base/database/query"
 )
 
 const (
@@ -23,41 +28,30 @@ func registerAPIEndpoints() error {
 }
 
 func handleReInit(ar *api.Request) (msg string, err error) {
-	// FIXME: make a better way to disable and enable spn
-	// // Disable module and check
-	// changed := module.Disable()
-	// if !changed {
-	// 	return "", errors.New("can only re-initialize when the SPN is enabled")
-	// }
-
-	// // Run module manager.
-	// err = modules.ManageModules()
-	// if err != nil {
-	// 	return "", fmt.Errorf("failed to stop SPN: %w", err)
-	// }
-
-	// // Delete SPN cache.
- // db := database.NewInterface(&database.Options{ - // Local: true, - // Internal: true, - // }) - // deletedRecords, err := db.Purge(ar.Context(), query.New("cache:spn/")) - // if err != nil { - // return "", fmt.Errorf("failed to delete SPN cache: %w", err) - // } - - // // Enable module. - // module.Enable() - - // // Run module manager. - // err = modules.ManageModules() - // if err != nil { - // return "", fmt.Errorf("failed to start SPN after cache reset: %w", err) - // } - - // return fmt.Sprintf( - // "Completed SPN re-initialization and deleted %d cache records in the process.", - // deletedRecords, - // ), nil - return "", nil + // Make sure SPN is stopped and wait for it to complete. + err = module.mgr.Do("stop SPN for re-init", module.instance.SPNGroup().EnsureStoppedWorker) + if err != nil { + return "", fmt.Errorf("failed to stop SPN for re-init: %w", err) + } + + // Delete SPN cache. + db := database.NewInterface(&database.Options{ + Local: true, + Internal: true, + }) + deletedRecords, err := db.Purge(ar.Context(), query.New("cache:spn/")) + if err != nil { + return "", fmt.Errorf("failed to delete SPN cache: %w", err) + } + + // Start SPN if it is enabled. + enabled := config.GetAsBool("spn/enable", false) + if enabled() { + module.mgr.Go("ensure SPN is started", module.instance.SPNGroup().EnsureStartedWorker) + } + + return fmt.Sprintf( + "Completed SPN re-initialization and deleted %d cache records in the process.", + deletedRecords, + ), nil } diff --git a/spn/captain/client.go b/spn/captain/client.go index 4a2546004..0827d317a 100644 --- a/spn/captain/client.go +++ b/spn/captain/client.go @@ -1,7 +1,6 @@ package captain import ( - "context" "errors" "fmt" "time" @@ -41,7 +40,7 @@ func ClientReady() bool { } type ( - clientComponentFunc func(ctx context.Context) clientComponentResult + clientComponentFunc func(ctx *mgr.WorkerCtx) clientComponentResult clientComponentResult uint8 ) @@ -78,7 +77,7 @@ func clientManager(ctx *mgr.WorkerCtx) error { netenv.ConnectedToSPN.UnSet() resetSPNStatus(StatusDisabled, true) module.states.Clear() - clientStopHomeHub(ctx.Ctx()) + clientStopHomeHub(ctx) }() module.states.Add(mgr.State{ @@ -105,10 +104,8 @@ func clientManager(ctx *mgr.WorkerCtx) error { reconnect: for { // Check if we are shutting down. - select { - case <-ctx.Done(): + if ctx.IsDone() { return nil - default: } // Reset SPN status. @@ -126,7 +123,7 @@ reconnect: clientConnectToHomeHub, clientSetActiveConnectionStatus, } { - switch clientFunc(ctx.Ctx()) { + switch clientFunc(ctx) { case clientResultOk: // Continue case clientResultRetry, clientResultReconnect: @@ -167,7 +164,7 @@ reconnect: clientCheckAccountAndTokens, clientSetActiveConnectionStatus, } { - switch clientFunc(ctx.Ctx()) { + switch clientFunc(ctx) { case clientResultOk: // Continue case clientResultRetry: @@ -194,7 +191,7 @@ reconnect: } } -func clientCheckNetworkReady(ctx context.Context) clientComponentResult { +func clientCheckNetworkReady(ctx *mgr.WorkerCtx) clientComponentResult { // Check if we are online enough for connecting. switch netenv.GetOnlineStatus() { //nolint:exhaustive case netenv.StatusOffline, @@ -214,7 +211,7 @@ func clientCheckNetworkReady(ctx context.Context) clientComponentResult { // Attempts to use the same will result in errors. 
var DisableAccount bool -func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { +func clientCheckAccountAndTokens(ctx *mgr.WorkerCtx) clientComponentResult { if DisableAccount { return clientResultOk } @@ -321,7 +318,7 @@ func clientCheckAccountAndTokens(ctx context.Context) clientComponentResult { return clientResultOk } -func clientStopHomeHub(ctx context.Context) clientComponentResult { +func clientStopHomeHub(ctx *mgr.WorkerCtx) clientComponentResult { // Don't use the context in this function, as it will likely be canceled // already and would disrupt any context usage in here. @@ -340,9 +337,13 @@ func clientStopHomeHub(ctx context.Context) clientComponentResult { return clientResultOk } -func clientConnectToHomeHub(ctx context.Context) clientComponentResult { +func clientConnectToHomeHub(ctx *mgr.WorkerCtx) clientComponentResult { err := establishHomeHub(ctx) if err != nil { + if ctx.IsDone() { + return clientResultShutdown + } + log.Errorf("spn/captain: failed to establish connection to home hub: %s", err) resetSPNStatus(StatusFailed, true) @@ -397,7 +398,7 @@ func clientConnectToHomeHub(ctx context.Context) clientComponentResult { return clientResultOk } -func clientSetActiveConnectionStatus(ctx context.Context) clientComponentResult { +func clientSetActiveConnectionStatus(ctx *mgr.WorkerCtx) clientComponentResult { // Get current home. home, homeTerminal := navigator.Main.GetHome() if home == nil || homeTerminal == nil { @@ -440,7 +441,7 @@ func clientSetActiveConnectionStatus(ctx context.Context) clientComponentResult return clientResultOk } -func clientCheckHomeHubConnection(ctx context.Context) clientComponentResult { +func clientCheckHomeHubConnection(ctx *mgr.WorkerCtx) clientComponentResult { // Check the status of the Home Hub. home, homeTerminal := navigator.Main.GetHome() if home == nil || homeTerminal == nil || homeTerminal.IsBeingAbandoned() { @@ -462,7 +463,7 @@ func clientCheckHomeHubConnection(ctx context.Context) clientComponentResult { // Prepare to reconnect to the network. // Reset all failing states, as these might have been caused by the failing home hub. - navigator.Main.ResetFailingStates(ctx) + navigator.Main.ResetFailingStates() // If the last health check is clearly too long ago, assume that the device was sleeping and do not set the home node to failing yet. if time.Since(lastHealthCheck) > clientHealthCheckTickDuration+ @@ -482,7 +483,7 @@ func clientCheckHomeHubConnection(ctx context.Context) clientComponentResult { return clientResultOk } -func pingHome(ctx context.Context, t terminal.Terminal, timeout time.Duration) (latency time.Duration, err *terminal.Error) { +func pingHome(ctx *mgr.WorkerCtx, t terminal.Terminal, timeout time.Duration) (latency time.Duration, err *terminal.Error) { started := time.Now() // Start ping operation. diff --git a/spn/captain/module.go b/spn/captain/module.go index a8ba0e567..0c3eb906e 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -33,8 +33,6 @@ type Captain struct { mgr *mgr.Manager instance instance - shutdownFunc func(exitCode int) - healthCheckTicker *mgr.SleepyTicker maintainPublicStatus *mgr.WorkerMgr @@ -140,7 +138,7 @@ func start() error { // Load identity. if err := loadPublicIdentity(); err != nil { // We cannot recover from this, set controlled failure (do not retry). - module.shutdownFunc(controlledFailureExitCode) + module.instance.Shutdown(controlledFailureExitCode) return err } @@ -148,7 +146,7 @@ func start() error { // Check if any networks are configured. 
if !conf.HubHasIPv4() && !conf.HubHasIPv6() { // We cannot recover from this, set controlled failure (do not retry). - module.shutdownFunc(controlledFailureExitCode) + module.instance.Shutdown(controlledFailureExitCode) return errors.New("no IP addresses for Hub configured (or detected)") } @@ -191,7 +189,7 @@ func start() error { // Reset failing hubs when the network changes while not connected. module.instance.NetEnv().EventNetworkChange.AddCallback("reset failing hubs", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { if ready.IsNotSet() { - navigator.Main.ResetFailingStates(module.mgr.Ctx()) + navigator.Main.ResetFailingStates() } return false, nil }) @@ -245,15 +243,14 @@ var ( ) // New returns a new Captain module. -func New(instance instance, shutdownFunc func(exitCode int)) (*Captain, error) { +func New(instance instance) (*Captain, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") } m := mgr.New("Captain") module = &Captain{ - mgr: m, - instance: instance, - shutdownFunc: shutdownFunc, + mgr: m, + instance: instance, states: mgr.NewStateMgr(m), EventSPNConnected: mgr.NewEventMgr[struct{}](SPNConnectedEvent, m), @@ -272,4 +269,6 @@ type instance interface { Patrol() *patrol.Patrol Config() *config.Config Updates() *updates.Updates + SPNGroup() *mgr.ExtendedGroup + Shutdown(exitCode int) } diff --git a/spn/captain/navigation.go b/spn/captain/navigation.go index 8b2a57a66..ac757616f 100644 --- a/spn/captain/navigation.go +++ b/spn/captain/navigation.go @@ -28,7 +28,7 @@ var ( ErrReInitSPNSuggested = errors.New("SPN re-init suggested") ) -func establishHomeHub(ctx context.Context) error { +func establishHomeHub(ctx *mgr.WorkerCtx) error { // Get own IP. locations, ok := netenv.GetInternetLocation() if !ok || len(locations.All) == 0 { @@ -46,10 +46,10 @@ func establishHomeHub(ctx context.Context) error { var myEntity *intel.Entity if dl := locations.BestV4(); dl != nil && dl.IP != nil { myEntity = (&intel.Entity{IP: dl.IP}).Init(0) - myEntity.FetchData(ctx) + myEntity.FetchData(ctx.Ctx()) } else if dl := locations.BestV6(); dl != nil && dl.IP != nil { myEntity = (&intel.Entity{IP: dl.IP}).Init(0) - myEntity.FetchData(ctx) + myEntity.FetchData(ctx.Ctx()) } // Get home hub policy for selecting the home hub. @@ -112,8 +112,8 @@ findCandidates: err = connectToHomeHub(ctx, candidate) if err != nil { // Check if context is canceled. - if ctx.Err() != nil { - return ctx.Err() + if ctx.IsDone() { + return ctx.Ctx().Err() } // Check if the SPN protocol is stopping again. if errors.Is(err, terminal.ErrStopping) { @@ -131,12 +131,12 @@ findCandidates: return errors.New("no home hub candidates available") } -func connectToHomeHub(ctx context.Context, dst *hub.Hub) error { +func connectToHomeHub(wCtx *mgr.WorkerCtx, dst *hub.Hub) error { // Create new context with timeout. // The maximum timeout is a worst case safeguard. // Keep in mind that multiple IPs and protocols may be tried in all configurations. // Some servers will be (possibly on purpose) hard to reach. - ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + ctx, cancel := context.WithTimeout(wCtx.Ctx(), 5*time.Minute) defer cancel() // Set and clean up exceptions. diff --git a/spn/navigator/update.go b/spn/navigator/update.go index dfe0c54c6..f0b1a3397 100644 --- a/spn/navigator/update.go +++ b/spn/navigator/update.go @@ -409,7 +409,7 @@ func (m *Map) updateHubLane(pin *Pin, lane *hub.Lane, peer *Pin) { } // ResetFailingStates resets the failing state on all pins. 
-func (m *Map) ResetFailingStates(ctx context.Context) { +func (m *Map) ResetFailingStates() { m.Lock() defer m.Unlock() From f51c5590cbeb3676bb8e361cafd40e210278785b Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Thu, 25 Jul 2024 17:05:03 +0300 Subject: [PATCH 37/56] [WIP] Fix most unit tests --- base/api/authentication_test.go | 8 -- base/api/init_test.go | 38 ++++++ base/api/main_test.go | 55 -------- base/config/get_test.go | 8 +- base/config/init_test.go | 35 +++++ base/config/main.go | 16 +++ base/database/interface_cache_test.go | 29 +++-- base/rng/rng_test.go | 8 +- base/rng/test/main.go | 71 +++++++--- base/template/module.go | 109 ---------------- base/template/module_test.go | 53 -------- cmds/hub/main.go | 83 ++++++------ cmds/notifier/main.go | 8 +- cmds/observation-hub/apprise.go | 48 ++++++- cmds/observation-hub/main.go | 17 ++- cmds/observation-hub/observe.go | 50 ++++++- service/core/pmtesting/testing.go | 136 -------------------- service/intel/geoip/init_test.go | 102 +++++++++++++++ service/intel/geoip/lookup_test.go | 12 ++ service/intel/geoip/module_test.go | 11 -- service/netenv/init_test.go | 99 ++++++++++++++ service/netenv/main_test.go | 12 +- service/process/module_test.go | 11 -- service/profile/endpoints/endpoints_test.go | 96 +++++++++++++- service/resolver/main_test.go | 125 +++++++++++++++++- spn/access/module.go | 2 +- spn/access/module_test.go | 48 ++++++- spn/access/token/module_test.go | 20 ++- spn/cabin/identity_test.go | 2 +- spn/cabin/keys_test.go | 2 +- spn/cabin/module_test.go | 53 +++++++- spn/cabin/verification_test.go | 2 +- spn/crew/module_test.go | 127 +++++++++++++++++- spn/crew/op_connect_test.go | 2 +- spn/docks/bandwidth_test.go | 2 +- spn/docks/crane_test.go | 12 +- spn/docks/module_test.go | 135 ++++++++++++++++++- spn/docks/op_capacity_test.go | 2 +- spn/docks/op_latency_test.go | 2 +- spn/docks/terminal_expansion_test.go | 12 +- spn/hub/hub_test.go | 105 ++++++++++++++- spn/navigator/module_test.go | 122 +++++++++++++++++- spn/ships/http_shared_test.go | 7 +- spn/terminal/module_test.go | 115 ++++++++++++++++- spn/terminal/terminal_test.go | 6 +- spn/terminal/testing.go | 66 +++++----- spn/unit/scheduler_test.go | 16 ++- spn/unit/unit_test.go | 16 ++- 48 files changed, 1532 insertions(+), 584 deletions(-) create mode 100644 base/api/init_test.go delete mode 100644 base/api/main_test.go create mode 100644 base/config/init_test.go delete mode 100644 base/template/module.go delete mode 100644 base/template/module_test.go delete mode 100644 service/core/pmtesting/testing.go create mode 100644 service/intel/geoip/init_test.go delete mode 100644 service/intel/geoip/module_test.go create mode 100644 service/netenv/init_test.go delete mode 100644 service/process/module_test.go diff --git a/base/api/authentication_test.go b/base/api/authentication_test.go index 3d7e7c504..40ce0efe7 100644 --- a/base/api/authentication_test.go +++ b/base/api/authentication_test.go @@ -55,14 +55,6 @@ func makeAuthTestPath(reading bool, p Permission) string { return fmt.Sprintf("/test/auth/write/%s", p) } -func init() { - // Set test authenticator. 
- err := SetAuthenticator(testAuthenticator) - if err != nil { - panic(err) - } -} - func TestPermissions(t *testing.T) { t.Parallel() diff --git a/base/api/init_test.go b/base/api/init_test.go new file mode 100644 index 000000000..a7e04c51f --- /dev/null +++ b/base/api/init_test.go @@ -0,0 +1,38 @@ +package api + +import ( + "testing" + + "github.com/safing/portmaster/base/config" +) + +type testInstance struct { + config *config.Config +} + +var _ instance = &testInstance{} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func (stub *testInstance) Ready() bool { + return true +} + +func TestMain(m *testing.M) { + SetDefaultAPIListenAddress("0.0.0.0:8080") + instance := &testInstance{} + var err error + module, err = New(instance) + if err != nil { + panic(err) + } + err = SetAuthenticator(testAuthenticator) + if err != nil { + panic(err) + } + m.Run() +} diff --git a/base/api/main_test.go b/base/api/main_test.go deleted file mode 100644 index 6f68ad241..000000000 --- a/base/api/main_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package api - -import ( -// "fmt" -// "os" -// "testing" - -// API depends on the database for the database api. -// _ "github.com/safing/portmaster/base/database/dbmodule" -// "github.com/safing/portmaster/base/dataroot" -) - -func init() { - defaultListenAddress = "127.0.0.1:8817" -} - -// func TestMain(m *testing.M) { -// // enable module for testing -// module.Enable() - -// // tmp dir for data root (db & config) -// tmpDir, err := os.MkdirTemp("", "portbase-testing-") -// if err != nil { -// fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err) -// os.Exit(1) -// } -// // initialize data dir -// err = dataroot.Initialize(tmpDir, 0o0755) -// if err != nil { -// fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) -// os.Exit(1) -// } - -// // start modules -// var exitCode int -// err = modules.Start() -// if err != nil { -// // starting failed -// fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err) -// exitCode = 1 -// } else { -// // run tests -// exitCode = m.Run() -// } - -// // shutdown -// _ = modules.Shutdown() -// if modules.GetExitStatusCode() != 0 { -// exitCode = modules.GetExitStatusCode() -// fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err) -// } -// // clean up and exit -// _ = os.RemoveAll(tmpDir) -// os.Exit(exitCode) -// } diff --git a/base/config/get_test.go b/base/config/get_test.go index 810631abc..d16b6c512 100644 --- a/base/config/get_test.go +++ b/base/config/get_test.go @@ -307,7 +307,7 @@ func BenchmarkGetAsStringCached(b *testing.B) { b.ResetTimer() // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { monkey() } } @@ -325,7 +325,7 @@ func BenchmarkGetAsStringRefetch(b *testing.B) { b.ResetTimer() // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { getValueCache("monkey", nil, OptTypeString) } } @@ -344,7 +344,7 @@ func BenchmarkGetAsIntCached(b *testing.B) { b.ResetTimer() // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { elephant() } } @@ -362,7 +362,7 @@ func BenchmarkGetAsIntRefetch(b *testing.B) { b.ResetTimer() // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { getValueCache("elephant", nil, OptTypeInt) } } diff --git a/base/config/init_test.go b/base/config/init_test.go new file mode 100644 index 000000000..53967044b --- /dev/null +++ b/base/config/init_test.go @@ -0,0 +1,35 @@ +package config + +import ( + "fmt" + "os" + 
"testing" +) + +type testInstance struct{} + +var _ instance = testInstance{} + +func (stub testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + ds, err := InitializeUnitTestDataroot("test-config") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + module, err = New(&testInstance{}) + if err != nil { + return fmt.Errorf("failed to initialize module: %w", err) + } + + m.Run() + return nil +} + +func TestMain(m *testing.M) { + if err := runTest(m); err != nil { + fmt.Printf("%s\n", err) + os.Exit(1) + } +} diff --git a/base/config/main.go b/base/config/main.go index 528e22871..b4eef3846 100644 --- a/base/config/main.go +++ b/base/config/main.go @@ -134,3 +134,19 @@ func GetActiveConfigValues() map[string]interface{} { return values } + +func InitializeUnitTestDataroot(testName string) (string, error) { + basePath, err := os.MkdirTemp("", fmt.Sprintf("portmaster-%s", testName)) + if err != nil { + return "", fmt.Errorf("failed to make tmp dir: %w", err) + } + + ds := utils.NewDirStructure(basePath, 0o0755) + SetDataRoot(ds) + err = dataroot.Initialize(basePath, 0o0755) + if err != nil { + return "", fmt.Errorf("failed to initialize dataroot: %w", err) + } + + return basePath, nil +} diff --git a/base/database/interface_cache_test.go b/base/database/interface_cache_test.go index cfed4388c..f3d8af36a 100644 --- a/base/database/interface_cache_test.go +++ b/base/database/interface_cache_test.go @@ -1,11 +1,12 @@ package database import ( - "context" "fmt" "strconv" "sync" "testing" + + "github.com/safing/portmaster/service/mgr" ) func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, sampleSize int, delayWrites bool) { //nolint:gocognit,gocyclo,thelper @@ -35,22 +36,23 @@ func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, samp db := NewInterface(options) // Start - ctx, cancelCtx := context.WithCancel(context.Background()) + m := mgr.New("Cache writing benchmark test") var wg sync.WaitGroup if cacheSize > 0 && delayWrites { wg.Add(1) - go func() { - err := db.DelayedCacheWriter(ctx) + m.Go("Cache writing benchmark worker", func(wc *mgr.WorkerCtx) error { + err := db.DelayedCacheWriter(wc) if err != nil { panic(err) } wg.Done() - }() + return nil + }) } // Start Benchmark. b.ResetTimer() - for i := 0; i < b.N; i++ { + for i := range b.N { testRecordID := i % sampleSize r := NewExample( dbName+":"+strconv.Itoa(testRecordID), @@ -64,7 +66,7 @@ func benchmarkCacheWriting(b *testing.B, storageType string, cacheSize int, samp } // End cache writer and wait - cancelCtx() + m.Cancel() wg.Wait() }) } @@ -96,23 +98,24 @@ func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sa db := NewInterface(options) // Start - ctx, cancelCtx := context.WithCancel(context.Background()) + m := mgr.New("Cache read/write benchmark test") var wg sync.WaitGroup if cacheSize > 0 && delayWrites { wg.Add(1) - go func() { - err := db.DelayedCacheWriter(ctx) + m.Go("Cache read/write benchmark worker", func(wc *mgr.WorkerCtx) error { + err := db.DelayedCacheWriter(wc) if err != nil { panic(err) } wg.Done() - }() + return nil + }) } // Start Benchmark. 
b.ResetTimer() writing := true - for i := 0; i < b.N; i++ { + for i := range b.N { testRecordID := i % sampleSize key := dbName + ":" + strconv.Itoa(testRecordID) @@ -132,7 +135,7 @@ func benchmarkCacheReadWrite(b *testing.B, storageType string, cacheSize int, sa } // End cache writer and wait - cancelCtx() + m.Cancel() wg.Wait() }) } diff --git a/base/rng/rng_test.go b/base/rng/rng_test.go index be70e17df..20cb6aaa5 100644 --- a/base/rng/rng_test.go +++ b/base/rng/rng_test.go @@ -5,7 +5,13 @@ import ( ) func init() { - err := start() + var err error + module, err = New(struct{}{}) + if err != nil { + panic(err) + } + + err = module.Start() if err != nil { panic(err) } diff --git a/base/rng/test/main.go b/base/rng/test/main.go index 68ad0cbe2..896d86ae7 100644 --- a/base/rng/test/main.go +++ b/base/rng/test/main.go @@ -1,43 +1,64 @@ package main import ( - "context" "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/binary" "encoding/hex" + "errors" "fmt" "io" "os" "runtime" "strconv" + "sync/atomic" "time" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/mgr" ) +type Test struct { + mgr *mgr.Manager + + instance instance +} + var ( - module *modules.Module + module *Test + shimLoaded atomic.Bool outputFile *os.File outputSize uint64 = 1000000 ) func init() { - module = modules.Register("main", prep, start, nil, "rng") + // module = modules.Register("main", prep, start, nil, "rng") } func main() { runtime.GOMAXPROCS(1) - os.Exit(run.Run()) + var err error + module, err = New(struct{}{}) + if err != nil { + fmt.Printf("failed to initialize module: %s", err) + return + } + + err = start() + if err != nil { + fmt.Printf("failed to initialize module: %s", err) + return + } } func prep() error { if len(os.Args) < 3 { fmt.Printf("usage: ./%s {fortuna|tickfeeder} [output size in MB]", os.Args[0]) - return modules.ErrCleanExit + return base.ErrCleanExit } switch os.Args[1] { @@ -72,11 +93,11 @@ func start() error { switch os.Args[1] { case "fortuna": - module.StartWorker("fortuna", fortuna) + module.mgr.Go("fortuna", fortuna) case "tickfeeder": - module.StartWorker("noise", noise) - module.StartWorker("tickfeeder", tickfeeder) + module.mgr.Go("noise", noise) + module.mgr.Go("tickfeeder", tickfeeder) default: return fmt.Errorf("usage: ./%s {fortuna|tickfeeder}", os.Args[0]) @@ -85,11 +106,11 @@ func start() error { return nil } -func fortuna(_ context.Context) error { +func fortuna(_ *mgr.WorkerCtx) error { var bytesWritten uint64 for { - if module.IsStopping() { + if module.mgr.IsDone() { return nil } @@ -115,17 +136,17 @@ func fortuna(_ context.Context) error { } } - go modules.Shutdown() //nolint:errcheck + go module.mgr.Cancel() //nolint:errcheck return nil } -func tickfeeder(ctx context.Context) error { +func tickfeeder(ctx *mgr.WorkerCtx) error { var bytesWritten uint64 var value int64 var pushes int for { - if module.IsStopping() { + if module.mgr.IsDone() { return nil } @@ -157,11 +178,11 @@ func tickfeeder(ctx context.Context) error { } } - go modules.Shutdown() //nolint:errcheck + go module.mgr.Cancel() //nolint:errcheck return nil } -func noise(ctx context.Context) error { +func noise(ctx *mgr.WorkerCtx) error { // do some aes ctr for noise key, _ := hex.DecodeString("6368616e676520746869732070617373") @@ -187,3 +208,23 @@ func noise(ctx context.Context) error { } } } + +func New(instance instance) (*Test, error) { + if !shimLoaded.CompareAndSwap(false, true) { + 
return nil, errors.New("only one instance allowed") + } + + m := mgr.New("geoip") + module = &Test{ + mgr: m, + instance: instance, + } + + if err := prep(); err != nil { + return nil, err + } + + return module, nil +} + +type instance interface{} diff --git a/base/template/module.go b/base/template/module.go deleted file mode 100644 index bbf3b71ad..000000000 --- a/base/template/module.go +++ /dev/null @@ -1,109 +0,0 @@ -package template - -// import ( -// "context" -// "time" - -// "github.com/safing/portmaster/base/config" -// ) - -// const ( -// eventStateUpdate = "state update" -// ) - -// var module *modules.Module - -// func init() { -// // register module -// module = modules.Register("template", prep, start, stop) // add dependencies... -// subsystems.Register( -// "template-subsystem", // ID -// "Template Subsystem", // name -// "This subsystem is a template for quick setup", // description -// module, -// "config:template", // key space for configuration options registered -// &config.Option{ -// Name: "Template Subsystem", -// Key: "config:subsystems/template", -// Description: "This option enables the Template Subsystem [TEMPLATE]", -// OptType: config.OptTypeBool, -// DefaultValue: false, -// }, -// ) - -// // register events that other modules can subscribe to -// module.RegisterEvent(eventStateUpdate, true) -// } - -// func prep() error { -// // register options -// err := config.Register(&config.Option{ -// Name: "language", -// Key: "template/language", -// Description: "Sets the language for the template [TEMPLATE]", -// OptType: config.OptTypeString, -// ExpertiseLevel: config.ExpertiseLevelUser, // default -// ReleaseLevel: config.ReleaseLevelStable, // default -// RequiresRestart: false, // default -// DefaultValue: "en", -// ValidationRegex: "^[a-z]{2}$", -// }) -// if err != nil { -// return err -// } - -// // register event hooks -// // do this in prep() and not in start(), as we don't want to register again if module is turned off and on again -// err = module.RegisterEventHook( -// "template", // event source module name -// "state update", // event source name -// "react to state changes", // description of hook function -// eventHandler, // hook function -// ) -// if err != nil { -// return err -// } - -// // hint: event hooks and tasks will not be run if module isn't online -// return nil -// } - -// func start() error { -// // register tasks -// module.NewTask("do something", taskFn).Queue() - -// // start service worker -// module.StartServiceWorker("do something", 0, serviceWorker) - -// return nil -// } - -// func stop() error { -// return nil -// } - -// func serviceWorker(ctx context.Context) error { -// for { -// select { -// case <-time.After(1 * time.Second): -// err := do() -// if err != nil { -// return err -// } -// case <-ctx.Done(): -// return nil -// } -// } -// } - -// func taskFn(ctx context.Context, task *modules.Task) error { -// return do() -// } - -// func eventHandler(ctx context.Context, data interface{}) error { -// return do() -// } - -// func do() error { -// return nil -// } diff --git a/base/template/module_test.go b/base/template/module_test.go deleted file mode 100644 index 824c195ef..000000000 --- a/base/template/module_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package template - -import ( - "fmt" - "os" - "testing" - - _ "github.com/safing/portmaster/base/database/dbmodule" - "github.com/safing/portmaster/base/dataroot" -) - -func TestMain(m *testing.M) { - // register base module, for database initialization - 
modules.Register("base", nil, nil, nil) - - // enable module for testing - module.Enable() - - // tmp dir for data root (db & config) - tmpDir, err := os.MkdirTemp("", "portbase-testing-") - if err != nil { - fmt.Fprintf(os.Stderr, "failed to create tmp dir: %s\n", err) - os.Exit(1) - } - // initialize data dir - err = dataroot.Initialize(tmpDir, 0o0755) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) - os.Exit(1) - } - - // start modules - var exitCode int - err = modules.Start() - if err != nil { - // starting failed - fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err) - exitCode = 1 - } else { - // run tests - exitCode = m.Run() - } - - // shutdown - _ = modules.Shutdown() - if modules.GetExitStatusCode() != 0 { - exitCode = modules.GetExitStatusCode() - fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err) - } - // clean up and exit - _ = os.RemoveAll(tmpDir) - os.Exit(exitCode) -} diff --git a/cmds/hub/main.go b/cmds/hub/main.go index dada02d05..4180e74d4 100644 --- a/cmds/hub/main.go +++ b/cmds/hub/main.go @@ -2,18 +2,18 @@ package main import ( "flag" - "fmt" - "os" - "runtime" + // "fmt" + // "os" + // "runtime" - "github.com/safing/portmaster/base/info" - "github.com/safing/portmaster/base/metrics" + // "github.com/safing/portmaster/base/info" + // "github.com/safing/portmaster/base/metrics" _ "github.com/safing/portmaster/service/core/base" _ "github.com/safing/portmaster/service/ui" "github.com/safing/portmaster/service/updates" - "github.com/safing/portmaster/service/updates/helper" + // "github.com/safing/portmaster/service/updates/helper" _ "github.com/safing/portmaster/spn/captain" - "github.com/safing/portmaster/spn/conf" + // "github.com/safing/portmaster/spn/conf" ) func init() { @@ -21,44 +21,45 @@ func init() { } func main() { - info.Set("SPN Hub", "0.7.7", "GPLv3") + // FIXME: rewrite so it fits the new module system + // info.Set("SPN Hub", "0.7.7", "GPLv3") - // Configure metrics. - _ = metrics.SetNamespace("hub") + // // Configure metrics. + // _ = metrics.SetNamespace("hub") - // Configure updating. - updates.UserAgent = fmt.Sprintf("SPN Hub (%s %s)", runtime.GOOS, runtime.GOARCH) - helper.IntelOnly() + // // Configure updating. + // updates.UserAgent = fmt.Sprintf("SPN Hub (%s %s)", runtime.GOOS, runtime.GOARCH) + // helper.IntelOnly() - // Configure SPN mode. - conf.EnablePublicHub(true) - conf.EnableClient(false) + // // Configure SPN mode. + // conf.EnablePublicHub(true) + // conf.EnableClient(false) - // Disable module management, as we want to start all modules. - modules.DisableModuleManagement() + // // Disable module management, as we want to start all modules. + // modules.DisableModuleManagement() - // Configure microtask threshold. - // Scale with CPU/GOMAXPROCS count, but keep a baseline and minimum: - // CPUs -> MicroTasks - // 0 -> 8 (increased to minimum) - // 1 -> 8 (increased to minimum) - // 2 -> 8 - // 3 -> 10 - // 4 -> 12 - // 8 -> 20 - // 16 -> 36 - // - // Start with number of GOMAXPROCS. - microTasksThreshold := runtime.GOMAXPROCS(0) * 2 - // Use at least 4 microtasks based on GOMAXPROCS. - if microTasksThreshold < 4 { - microTasksThreshold = 4 - } - // Add a 4 microtask baseline. - microTasksThreshold += 4 - // Set threshold. - modules.SetMaxConcurrentMicroTasks(microTasksThreshold) + // // Configure microtask threshold. 
+ // // Scale with CPU/GOMAXPROCS count, but keep a baseline and minimum: + // // CPUs -> MicroTasks + // // 0 -> 8 (increased to minimum) + // // 1 -> 8 (increased to minimum) + // // 2 -> 8 + // // 3 -> 10 + // // 4 -> 12 + // // 8 -> 20 + // // 16 -> 36 + // // + // // Start with number of GOMAXPROCS. + // microTasksThreshold := runtime.GOMAXPROCS(0) * 2 + // // Use at least 4 microtasks based on GOMAXPROCS. + // if microTasksThreshold < 4 { + // microTasksThreshold = 4 + // } + // // Add a 4 microtask baseline. + // microTasksThreshold += 4 + // // Set threshold. + // modules.SetMaxConcurrentMicroTasks(microTasksThreshold) - // Start. - os.Exit(run.Run()) + // // Start. + // os.Exit(run.Run()) } diff --git a/cmds/notifier/main.go b/cmds/notifier/main.go index 164aeb003..e40487bbc 100644 --- a/cmds/notifier/main.go +++ b/cmds/notifier/main.go @@ -76,10 +76,10 @@ func main() { } // print help - if modules.HelpFlag { - flag.Usage() - os.Exit(0) - } + // if modules.HelpFlag { + // flag.Usage() + // os.Exit(0) + // } if showVersion { fmt.Println(info.FullVersion()) diff --git a/cmds/observation-hub/apprise.go b/cmds/observation-hub/apprise.go index 501c06131..c64d3dedb 100644 --- a/cmds/observation-hub/apprise.go +++ b/cmds/observation-hub/apprise.go @@ -9,17 +9,38 @@ import ( "fmt" "net/http" "strings" + "sync/atomic" "text/template" "time" "github.com/safing/portmaster/base/apprise" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/intel/geoip" + "github.com/safing/portmaster/service/mgr" ) +type Apprise struct { + mgr *mgr.Manager + + instance instance +} + +func (a *Apprise) Manager() *mgr.Manager { + return a.mgr +} + +func (a *Apprise) Start() error { + return startApprise() +} + +func (a *Apprise) Stop() error { + return nil +} + var ( - appriseModule *modules.Module - appriseNotifier *apprise.Notifier + appriseModule *Apprise + appriseShimLoaded atomic.Bool + appriseNotifier *apprise.Notifier appriseURL string appriseTag string @@ -29,7 +50,7 @@ var ( ) func init() { - appriseModule = modules.Register("apprise", nil, startApprise, nil) + // appriseModule = modules.Register("apprise", nil, startApprise, nil) flag.StringVar(&appriseURL, "apprise-url", "", "set the apprise URL to enable notifications via apprise") flag.StringVar(&appriseTag, "apprise-tag", "", "set the apprise tag(s) according to their docs") @@ -77,7 +98,7 @@ func startApprise() error { } if appriseGreet { - err := appriseNotifier.Send(appriseModule.Ctx, &apprise.Message{ + err := appriseNotifier.Send(appriseModule.mgr.Ctx(), &apprise.Message{ Title: "👋 Observation Hub Reporting In", Body: "I am the Observation Hub. I am connected to the SPN and watch out for it. I will report notable changes to the network here.", }) @@ -100,7 +121,7 @@ func reportToApprise(change *observedChange) (errs error) { handleTag: for _, tag := range strings.Split(appriseNotifier.DefaultTag, ",") { // Check if we are shutting down. - if appriseModule.IsStopping() { + if appriseModule.mgr.IsDone() { return nil } @@ -128,7 +149,7 @@ handleTag: var err error for i := 0; i < 3; i++ { // Try three times. - err = appriseNotifier.Send(appriseModule.Ctx, &apprise.Message{ + err = appriseNotifier.Send(appriseModule.mgr.Ctx(), &apprise.Message{ Body: buf.String(), Tag: tag, }) @@ -254,3 +275,18 @@ func getCountryInfo(code string) geoip.CountryInfo { // panic(err) // } // } + +// New returns a new Apprise module. 
+func NewApprise(instance instance) (*Apprise, error) {
+	if !appriseShimLoaded.CompareAndSwap(false, true) {
+		return nil, errors.New("only one instance allowed")
+	}
+
+	m := mgr.New("apprise")
+	appriseModule = &Apprise{
+		mgr:      m,
+		instance: instance,
+	}
+
+	return appriseModule, nil
+}
diff --git a/cmds/observation-hub/main.go b/cmds/observation-hub/main.go
index cf3c11144..bf83a4ab4 100644
--- a/cmds/observation-hub/main.go
+++ b/cmds/observation-hub/main.go
@@ -2,7 +2,6 @@ package main
 
 import (
 	"fmt"
-	"os"
 	"runtime"
 
 	"github.com/safing/portmaster/base/api"
@@ -34,9 +33,21 @@ func main() {
 	sluice.EnableListener = false
 	api.EnableServer = false
 
+	/// TODO(vladimir) initialize dependency modules
+
 	// Disable module management, as we want to start all modules.
-	modules.DisableModuleManagement()
+	// module.DisableModuleManagement()
+	module, err := New(struct{}{})
+	if err != nil {
+		fmt.Printf("error creating observer: %s\n", err)
+		return
+	}
 
+	err = module.Start()
+	if err != nil {
+		fmt.Printf("failed to start observer: %s\n", err)
+		return
+	}
 	// Start.
-	os.Exit(run.Run())
+	// os.Exit(run.Start())
 }
diff --git a/cmds/observation-hub/observe.go b/cmds/observation-hub/observe.go
index cec4c687c..76f0c0a96 100644
--- a/cmds/observation-hub/observe.go
+++ b/cmds/observation-hub/observe.go
@@ -1,12 +1,12 @@
 package main
 
 import (
-	"context"
 	"errors"
 	"flag"
 	"fmt"
 	"path"
 	"strings"
+	"sync/atomic"
 	"time"
 
 	diff "github.com/r3labs/diff/v3"
@@ -15,12 +15,31 @@ import (
 	"github.com/safing/portmaster/base/database"
 	"github.com/safing/portmaster/base/database/query"
 	"github.com/safing/portmaster/base/log"
+	"github.com/safing/portmaster/service/mgr"
 	"github.com/safing/portmaster/spn/captain"
 	"github.com/safing/portmaster/spn/navigator"
 )
 
+type Observer struct {
+	mgr      *mgr.Manager
+	instance instance
+}
+
+func (o *Observer) Manager() *mgr.Manager {
+	return o.mgr
+}
+
+func (o *Observer) Start() error {
+	return startObserver()
+}
+
+func (o *Observer) Stop() error {
+	return nil
+}
+
 var (
-	observerModule *modules.Module
+	observerModule *Observer
+	shimLoaded     atomic.Bool
 
 	db = database.NewInterface(&database.Options{
 		Local:    true,
@@ -36,7 +55,7 @@ var (
 )
 
 func init() {
-	observerModule = modules.Register("observer", prepObserver, startObserver, nil, "captain", "apprise")
+	// observerModule = modules.Register("observer", prepObserver, startObserver, nil, "captain", "apprise")
 
 	flag.BoolVar(&reportAllChanges, "report-all-changes", false, "report all changes, no just interesting ones")
 	flag.StringVar(&reportingDelayFlag, "reporting-delay", "10m", "delay reports to summarize changes")
@@ -55,7 +74,7 @@ func prepObserver() error {
 }
 
 func startObserver() error {
-	observerModule.StartServiceWorker("observer", 0, observerWorker)
+	observerModule.mgr.Go("observer", observerWorker)
 
 	return nil
 }
@@ -78,7 +97,7 @@ type observedChange struct {
 	SPNStatus *captain.SPNStatus
 }
 
-func observerWorker(ctx context.Context) error {
+func observerWorker(ctx *mgr.WorkerCtx) error {
 	log.Info("observer: starting")
 	defer log.Info("observer: stopped")
 
@@ -404,3 +423,24 @@ func makeHubName(name, id string) string {
 		return fmt.Sprintf("%s (%s)", name, shortenedID)
 	}
 }
+
+// New returns a new Observer module.
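+// The instance argument is currently an empty interface (see the instance
+// type at the end of this file), so callers such as
+// cmds/observation-hub/main.go in this patch can simply pass struct{}{}.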
+func New(instance instance) (*Observer, error) { + if !shimLoaded.CompareAndSwap(false, true) { + return nil, errors.New("only one instance allowed") + } + + m := mgr.New("observer") + observerModule = &Observer{ + mgr: m, + instance: instance, + } + + if err := prepObserver(); err != nil { + return nil, err + } + + return observerModule, nil +} + +type instance interface{} diff --git a/service/core/pmtesting/testing.go b/service/core/pmtesting/testing.go deleted file mode 100644 index 41eae4b60..000000000 --- a/service/core/pmtesting/testing.go +++ /dev/null @@ -1,136 +0,0 @@ -// Package pmtesting provides a simple unit test setup routine. -// -// Usage: -// -// package name -// -// import ( -// "testing" -// -// "github.com/safing/portmaster/service/core/pmtesting" -// ) -// -// func TestMain(m *testing.M) { -// pmtesting.TestMain(m, module) -// } -package pmtesting - -import ( - "flag" - "fmt" - "os" - "path/filepath" - "runtime/pprof" - "testing" - - _ "github.com/safing/portmaster/base/database/storage/hashmap" - "github.com/safing/portmaster/base/dataroot" - "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/service/core/base" -) - -var printStackOnExit bool - -func init() { - flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down") -} - -// TestHookFunc describes the functions passed to TestMainWithHooks. -type TestHookFunc func() error - -// TestMain provides a simple unit test setup routine. -func TestMain(m *testing.M, module *modules.Module) { - TestMainWithHooks(m, module, nil, nil) -} - -// TestMainWithHooks provides a simple unit test setup routine and calls -// afterStartFn after modules have started and beforeStopFn before modules -// are shutdown. -func TestMainWithHooks(m *testing.M, module *modules.Module, afterStartFn, beforeStopFn TestHookFunc) { - // Only enable needed modules. - modules.EnableModuleManagement(nil) - - // Enable this module for testing. - if module != nil { - module.Enable() - } - - // switch databases to memory only - base.DefaultDatabaseStorageType = "hashmap" - - // switch API to high port - base.DefaultAPIListenAddress = "127.0.0.1:10817" - - // set log level - log.SetLogLevel(log.TraceLevel) - - // tmp dir for data root (db & config) - tmpDir := filepath.Join(os.TempDir(), "portmaster-testing") - // initialize data dir - err := dataroot.Initialize(tmpDir, 0o0755) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to initialize data root: %s\n", err) - os.Exit(1) - } - - // start modules - var exitCode int - err = modules.Start() - if err != nil { - // starting failed - fmt.Fprintf(os.Stderr, "failed to setup test: %s\n", err) - exitCode = 1 - } else { - runTests := true - if afterStartFn != nil { - if err := afterStartFn(); err != nil { - fmt.Fprintf(os.Stderr, "failed to run test start hook: %s\n", err) - runTests = false - exitCode = 1 - } - } - - if runTests { - // run tests - exitCode = m.Run() - } - } - - if beforeStopFn != nil { - if err := beforeStopFn(); err != nil { - fmt.Fprintf(os.Stderr, "failed to run test shutdown hook: %s\n", err) - } - } - - // shutdown - _ = modules.Shutdown() - if modules.GetExitStatusCode() != 0 { - exitCode = modules.GetExitStatusCode() - fmt.Fprintf(os.Stderr, "failed to cleanly shutdown test: %s\n", err) - } - printStack() - - // clean up and exit - - // Important: Do not remove tmpDir, as it is used as a cache for updates. 
- // remove config - _ = os.Remove(filepath.Join(tmpDir, "config.json")) - // remove databases - _ = os.Remove(filepath.Join(tmpDir, "databases.json")) - _ = os.RemoveAll(filepath.Join(tmpDir, "databases")) - - os.Exit(exitCode) -} - -func printStack() { - if printStackOnExit { - fmt.Println("=== PRINTING TRACES ===") - fmt.Println("=== GOROUTINES ===") - _ = pprof.Lookup("goroutine").WriteTo(os.Stdout, 2) - fmt.Println("=== BLOCKING ===") - _ = pprof.Lookup("block").WriteTo(os.Stdout, 2) - fmt.Println("=== MUTEXES ===") - _ = pprof.Lookup("mutex").WriteTo(os.Stdout, 2) - fmt.Println("=== END TRACES ===") - } -} diff --git a/service/intel/geoip/init_test.go b/service/intel/geoip/init_test.go new file mode 100644 index 000000000..52c540324 --- /dev/null +++ b/service/intel/geoip/init_test.go @@ -0,0 +1,102 @@ +package geoip + +import ( + "fmt" + "os" + "testing" + + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/service/updates" +) + +type testInstance struct { + db *dbmodule.DBModule + api *api.API + config *config.Config + updates *updates.Updates +} + +var _ instance = &testInstance{} + +func (stub *testInstance) Updates() *updates.Updates { + return stub.updates +} + +func (stub *testInstance) API() *api.API { + return stub.api +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) Ready() bool { + return true +} + +func (stub *testInstance) Shutdown(exitCode int) {} + +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + api.SetDefaultAPIListenAddress("0.0.0.0:8080") + ds, err := config.InitializeUnitTestDataroot("test-geoip") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + stub := &testInstance{} + stub.db, err = dbmodule.New(stub) + if err != nil { + return fmt.Errorf("failed to create database: %w", err) + } + stub.config, err = config.New(stub) + if err != nil { + return fmt.Errorf("failed to create config: %w", err) + } + stub.api, err = api.New(stub) + if err != nil { + return fmt.Errorf("failed to create api: %w", err) + } + stub.updates, err = updates.New(stub) + if err != nil { + return fmt.Errorf("failed to create updates: %w", err) + } + module, err = New(stub) + if err != nil { + return fmt.Errorf("failed to initialize module: %w", err) + } + + err = stub.db.Start() + if err != nil { + return fmt.Errorf("Failed to start database: %w", err) + } + err = stub.config.Start() + if err != nil { + return fmt.Errorf("Failed to start config: %w", err) + } + err = stub.api.Start() + if err != nil { + return fmt.Errorf("Failed to start api: %w", err) + } + err = stub.updates.Start() + if err != nil { + return fmt.Errorf("Failed to start updates: %w", err) + } + err = module.Start() + if err != nil { + return fmt.Errorf("failed to start module: %w", err) + } + m.Run() + return nil +} + +func TestMain(m *testing.M) { + if err := runTest(m); err != nil { + fmt.Printf("%s\n", err) + os.Exit(1) + } +} diff --git a/service/intel/geoip/lookup_test.go b/service/intel/geoip/lookup_test.go index 0b9ca27b0..4f882377a 100644 --- a/service/intel/geoip/lookup_test.go +++ b/service/intel/geoip/lookup_test.go @@ -3,6 +3,7 @@ package geoip import ( "net" "testing" + "time" ) func TestLocationLookup(t *testing.T) { @@ -12,6 +13,17 @@ func TestLocationLookup(t *testing.T) { } t.Parallel() + // 
Wait for db to be initialized + worker.v4.rw.Lock() + waiter := worker.v4.getWaiter() + worker.v4.rw.Unlock() + + worker.triggerUpdate() + select { + case <-waiter: + case <-time.After(15 * time.Second): + } + ip1 := net.ParseIP("81.2.69.142") loc1, err := GetLocation(ip1) if err != nil { diff --git a/service/intel/geoip/module_test.go b/service/intel/geoip/module_test.go deleted file mode 100644 index c223d9209..000000000 --- a/service/intel/geoip/module_test.go +++ /dev/null @@ -1,11 +0,0 @@ -package geoip - -import ( - "testing" - - "github.com/safing/portmaster/service/core/pmtesting" -) - -func TestMain(m *testing.M) { - pmtesting.TestMain(m, module) -} diff --git a/service/netenv/init_test.go b/service/netenv/init_test.go new file mode 100644 index 000000000..bce6c4925 --- /dev/null +++ b/service/netenv/init_test.go @@ -0,0 +1,99 @@ +package netenv + +import ( + "fmt" + "os" + "testing" + + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/service/updates" +) + +type testInstance struct { + db *dbmodule.DBModule + api *api.API + config *config.Config + updates *updates.Updates +} + +var _ instance = &testInstance{} + +func (stub *testInstance) Updates() *updates.Updates { + return stub.updates +} + +func (stub *testInstance) API() *api.API { + return stub.api +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) Ready() bool { + return true +} + +func (stub *testInstance) Shutdown(exitCode int) {} + +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + api.SetDefaultAPIListenAddress("0.0.0.0:8080") + ds, err := config.InitializeUnitTestDataroot("test-netenv") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + stub := &testInstance{} + stub.db, err = dbmodule.New(stub) + if err != nil { + return fmt.Errorf("failed to create database: %w", err) + } + stub.config, err = config.New(stub) + if err != nil { + return fmt.Errorf("failed to create config: %w", err) + } + stub.api, err = api.New(stub) + if err != nil { + return fmt.Errorf("failed to create api: %w", err) + } + stub.updates, err = updates.New(stub) + if err != nil { + return fmt.Errorf("failed to create updates: %w", err) + } + err = stub.db.Start() + if err != nil { + return fmt.Errorf("Failed to start database: %w", err) + } + err = stub.config.Start() + if err != nil { + return fmt.Errorf("Failed to start config: %w", err) + } + err = stub.api.Start() + if err != nil { + return fmt.Errorf("Failed to start api: %w", err) + } + err = stub.updates.Start() + if err != nil { + return fmt.Errorf("Failed to start updates: %w", err) + } + + _, err = New(stub) + if err != nil { + return fmt.Errorf("failed to initialize module %s", err) + } + + m.Run() + return nil +} + +func TestMain(m *testing.M) { + if err := runTest(m); err != nil { + fmt.Printf("%s", err) + os.Exit(1) + } +} diff --git a/service/netenv/main_test.go b/service/netenv/main_test.go index 64588b387..791936296 100644 --- a/service/netenv/main_test.go +++ b/service/netenv/main_test.go @@ -1,11 +1,5 @@ package netenv -import ( - "testing" - - "github.com/safing/portmaster/service/core/pmtesting" -) - -func TestMain(m *testing.M) { - pmtesting.TestMain(m, module) -} +// func TestMain(m *testing.M) { +// pmtesting.TestMain(m, module) +// } diff --git 
a/service/process/module_test.go b/service/process/module_test.go deleted file mode 100644 index f2350d947..000000000 --- a/service/process/module_test.go +++ /dev/null @@ -1,11 +0,0 @@ -package process - -import ( - "testing" - - "github.com/safing/portmaster/service/core/pmtesting" -) - -func TestMain(m *testing.M) { - pmtesting.TestMain(m, module) -} diff --git a/service/profile/endpoints/endpoints_test.go b/service/profile/endpoints/endpoints_test.go index 8dafe10d2..f4ed1b562 100644 --- a/service/profile/endpoints/endpoints_test.go +++ b/service/profile/endpoints/endpoints_test.go @@ -2,18 +2,110 @@ package endpoints import ( "context" + "fmt" "net" + "os" "runtime" "testing" "github.com/stretchr/testify/assert" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" "github.com/safing/portmaster/service/intel" + "github.com/safing/portmaster/service/intel/geoip" + "github.com/safing/portmaster/service/updates" ) +type testInstance struct { + db *dbmodule.DBModule + api *api.API + config *config.Config + updates *updates.Updates + geoip *geoip.GeoIP +} + +func (stub *testInstance) Updates() *updates.Updates { + return stub.updates +} + +func (stub *testInstance) API() *api.API { + return stub.api +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) Ready() bool { + return true +} + +func (stub *testInstance) Shutdown(exitCode int) {} + +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + api.SetDefaultAPIListenAddress("0.0.0.0:8080") + ds, err := config.InitializeUnitTestDataroot("test-endpoints") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + stub := &testInstance{} + stub.db, err = dbmodule.New(stub) + if err != nil { + return fmt.Errorf("failed to create database: %w", err) + } + stub.config, err = config.New(stub) + if err != nil { + return fmt.Errorf("failed to create config: %w", err) + } + stub.api, err = api.New(stub) + if err != nil { + return fmt.Errorf("failed to create api: %w", err) + } + stub.updates, err = updates.New(stub) + if err != nil { + return fmt.Errorf("failed to create updates: %w", err) + } + stub.geoip, err = geoip.New(stub) + if err != nil { + return fmt.Errorf("failed to create geoip: %w", err) + } + + err = stub.db.Start() + if err != nil { + return fmt.Errorf("Failed to start database: %w", err) + } + err = stub.config.Start() + if err != nil { + return fmt.Errorf("Failed to start config: %w", err) + } + err = stub.api.Start() + if err != nil { + return fmt.Errorf("Failed to start api: %w", err) + } + err = stub.updates.Start() + if err != nil { + return fmt.Errorf("Failed to start updates: %w", err) + } + err = stub.geoip.Start() + if err != nil { + return fmt.Errorf("Failed to start geoip: %w", err) + } + + m.Run() + return nil +} + func TestMain(m *testing.M) { - pmtesting.TestMain(m, intel.Module) + if err := runTest(m); err != nil { + fmt.Printf("%s", err) + os.Exit(1) + } } func testEndpointMatch(t *testing.T, ep Endpoint, entity *intel.Entity, expectedResult EPResult) { diff --git a/service/resolver/main_test.go b/service/resolver/main_test.go index 2a2dbe447..44890ed05 100644 --- a/service/resolver/main_test.go +++ b/service/resolver/main_test.go @@ -1,15 +1,136 @@ package resolver import ( + "fmt" + "os" "testing" 
- "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/netenv" + "github.com/safing/portmaster/service/updates" ) var domainFeed = make(chan string) +type testInstance struct { + db *dbmodule.DBModule + base *base.Base + api *api.API + config *config.Config + updates *updates.Updates + netenv *netenv.NetEnv +} + +// var _ instance = &testInstance{} + +func (stub *testInstance) Updates() *updates.Updates { + return stub.updates +} + +func (stub *testInstance) API() *api.API { + return stub.api +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) NetEnv() *netenv.NetEnv { + return stub.netenv +} + +func (stub *testInstance) Ready() bool { + return true +} + +func (stub *testInstance) Shutdown(exitCode int) {} + +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func (i *testInstance) GetEventSPNConnected() *mgr.EventMgr[struct{}] { + return mgr.NewEventMgr[struct{}]("spn connect", nil) +} + +func runTest(m *testing.M) error { + api.SetDefaultAPIListenAddress("0.0.0.0:8080") + ds, err := config.InitializeUnitTestDataroot("test-resolver") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + stub := &testInstance{} + stub.db, err = dbmodule.New(stub) + if err != nil { + return fmt.Errorf("failed to create database: %w", err) + } + stub.config, err = config.New(stub) + if err != nil { + return fmt.Errorf("failed to create config: %w", err) + } + stub.base, err = base.New(stub) + if err != nil { + return fmt.Errorf("failed to create base: %w", err) + } + stub.api, err = api.New(stub) + if err != nil { + return fmt.Errorf("failed to create api: %w", err) + } + stub.netenv, err = netenv.New(stub) + if err != nil { + return fmt.Errorf("failed to create netenv: %w", err) + } + stub.updates, err = updates.New(stub) + if err != nil { + return fmt.Errorf("failed to create updates: %w", err) + } + module, err := New(stub) + if err != nil { + return fmt.Errorf("failed to create module: %w", err) + } + + err = stub.db.Start() + if err != nil { + return fmt.Errorf("Failed to start database: %w", err) + } + err = stub.config.Start() + if err != nil { + return fmt.Errorf("Failed to start config: %w", err) + } + err = stub.base.Start() + if err != nil { + return fmt.Errorf("Failed to start base: %w", err) + } + err = stub.api.Start() + if err != nil { + return fmt.Errorf("Failed to start api: %w", err) + } + err = stub.updates.Start() + if err != nil { + return fmt.Errorf("Failed to start updates: %w", err) + } + err = stub.netenv.Start() + if err != nil { + return fmt.Errorf("Failed to start netenv: %w", err) + } + err = module.Start() + if err != nil { + return fmt.Errorf("Failed to start module: %w", err) + } + + m.Run() + return nil +} + func TestMain(m *testing.M) { - pmtesting.TestMain(m, module) + if err := runTest(m); err != nil { + fmt.Printf("%s", err) + os.Exit(1) + } } func init() { diff --git a/spn/access/module.go b/spn/access/module.go index c407518a2..5902e7d5c 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -115,7 +115,7 @@ func stop() error { // Make sure SPN is stopped before we proceed. 
err := module.mgr.Do("ensure SPN is shut down", module.instance.SPNGroup().EnsureStoppedWorker) if err != nil { - log.Errorf("access: stop SPN: %w", err) + log.Errorf("access: stop SPN: %s", err) } if conf.Client() { diff --git a/spn/access/module_test.go b/spn/access/module_test.go index 59d69be67..5b0180101 100644 --- a/spn/access/module_test.go +++ b/spn/access/module_test.go @@ -1,13 +1,57 @@ package access import ( + "fmt" + "os" "testing" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) +type testInstance struct { + config *config.Config +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) SPNGroup() *mgr.ExtendedGroup { + return nil +} + +func (stub *testInstance) Stopping() bool { + return false +} +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + func TestMain(m *testing.M) { + instance := &testInstance{} + var err error + instance.config, err = config.New(instance) + if err != nil { + fmt.Printf("failed to create config module: %s", err) + os.Exit(0) + } + module, err = New(instance) + if err != nil { + fmt.Printf("failed to create access module: %s", err) + os.Exit(0) + } + + err = instance.config.Start() + if err != nil { + fmt.Printf("failed to start config module: %s", err) + os.Exit(0) + } + err = module.Start() + if err != nil { + fmt.Printf("failed to start access module: %s", err) + os.Exit(0) + } + conf.EnableClient(true) - pmtesting.TestMain(m, module) + m.Run() } diff --git a/spn/access/token/module_test.go b/spn/access/token/module_test.go index 2d00460aa..9696df67b 100644 --- a/spn/access/token/module_test.go +++ b/spn/access/token/module_test.go @@ -1,12 +1,26 @@ package token import ( + "fmt" + "os" "testing" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/rng" ) +type testInstance struct{} + func TestMain(m *testing.M) { - module := modules.Register("token", nil, nil, nil, "rng") - pmtesting.TestMain(m, module) + rng, err := rng.New(testInstance{}) + if err != nil { + fmt.Printf("failed to create RNG module: %s", err) + os.Exit(1) + } + + err = rng.Start() + if err != nil { + fmt.Printf("failed to start RNG module: %s", err) + os.Exit(1) + } + m.Run() } diff --git a/spn/cabin/identity_test.go b/spn/cabin/identity_test.go index 6ad0530d4..e0e9ea4f2 100644 --- a/spn/cabin/identity_test.go +++ b/spn/cabin/identity_test.go @@ -20,7 +20,7 @@ func TestIdentity(t *testing.T) { // Create new identity. 
identityTestKey := "core:spn/public/identity" - id, err := CreateIdentity(module.Ctx, conf.MainMapName) + id, err := CreateIdentity(module.m.Ctx(), conf.MainMapName) if err != nil { t.Fatal(err) } diff --git a/spn/cabin/keys_test.go b/spn/cabin/keys_test.go index c1622fe63..4d135f012 100644 --- a/spn/cabin/keys_test.go +++ b/spn/cabin/keys_test.go @@ -10,7 +10,7 @@ import ( func TestKeyMaintenance(t *testing.T) { t.Parallel() - id, err := CreateIdentity(module.Ctx, conf.MainMapName) + id, err := CreateIdentity(module.m.Ctx(), conf.MainMapName) if err != nil { t.Fatal(err) } diff --git a/spn/cabin/module_test.go b/spn/cabin/module_test.go index c2d66ed1a..13387a1d6 100644 --- a/spn/cabin/module_test.go +++ b/spn/cabin/module_test.go @@ -1,13 +1,60 @@ package cabin import ( + "fmt" + "os" "testing" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) -func TestMain(m *testing.M) { +type testInstance struct { + config *config.Config +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) SPNGroup() *mgr.ExtendedGroup { + return nil +} + +func (stub *testInstance) Stopping() bool { + return false +} +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + instance := &testInstance{} + var err error + instance.config, err = config.New(instance) + if err != nil { + return fmt.Errorf("failed to create config module: %w", err) + } + module, err = New(struct{}{}) + if err != nil { + return fmt.Errorf("failed to create cabin module: %w", err) + } + err = instance.config.Start() + if err != nil { + return fmt.Errorf("failed to start config module: %w", err) + } + err = module.Start() + if err != nil { + return fmt.Errorf("failed to start cabin module: %w", err) + } conf.EnablePublicHub(true) - pmtesting.TestMain(m, module) + + m.Run() + return nil +} + +func TestMain(m *testing.M) { + if err := runTest(m); err != nil { + fmt.Printf("%s\n", err) + os.Exit(1) + } } diff --git a/spn/cabin/verification_test.go b/spn/cabin/verification_test.go index cb743a3d8..a498bff27 100644 --- a/spn/cabin/verification_test.go +++ b/spn/cabin/verification_test.go @@ -8,7 +8,7 @@ import ( func TestVerification(t *testing.T) { t.Parallel() - id, err := CreateIdentity(module.Ctx, "test") + id, err := CreateIdentity(module.m.Ctx(), "test") if err != nil { t.Fatal(err) } diff --git a/spn/crew/module_test.go b/spn/crew/module_test.go index 7c0a7ad74..d31184f8a 100644 --- a/spn/crew/module_test.go +++ b/spn/crew/module_test.go @@ -1,13 +1,134 @@ package crew import ( + "fmt" + "os" "testing" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" + "github.com/safing/portmaster/spn/terminal" ) -func TestMain(m *testing.M) { +type testInstance struct { + db *dbmodule.DBModule + config *config.Config + metrics *metrics.Metrics + rng *rng.Rng + base *base.Base + terminal *terminal.TerminalModule + cabin *cabin.Cabin +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) 
Metrics() *metrics.Metrics { + return stub.metrics +} + +func (stub *testInstance) SPNGroup() *mgr.ExtendedGroup { + return nil +} + +func (stub *testInstance) Stopping() bool { + return false +} +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + conf.EnablePublicHub(true) // Make hub config available. + ds, err := config.InitializeUnitTestDataroot("test-crew") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + instance := &testInstance{} + // Init + instance.db, err = dbmodule.New(instance) + if err != nil { + return fmt.Errorf("failed to create database module: %w", err) + } + instance.config, err = config.New(instance) + if err != nil { + return fmt.Errorf("failed to create config module: %w", err) + } + instance.metrics, err = metrics.New(instance) + if err != nil { + return fmt.Errorf("failed to create metrics module: %w", err) + } + instance.rng, err = rng.New(instance) + if err != nil { + return fmt.Errorf("failed to create rng module: %w", err) + } + instance.base, err = base.New(instance) + if err != nil { + return fmt.Errorf("failed to create base module: %w", err) + } + instance.terminal, err = terminal.New(instance) + if err != nil { + return fmt.Errorf("failed to create terminal module: %w", err) + } + instance.cabin, err = cabin.New(instance) + if err != nil { + return fmt.Errorf("failed to create cabin module: %w", err) + } + module, err = New(instance) + if err != nil { + return fmt.Errorf("failed to create crew module: %w", err) + } + + // Start + err = instance.db.Start() + if err != nil { + return fmt.Errorf("failed to start db module: %w", err) + } + err = instance.config.Start() + if err != nil { + return fmt.Errorf("failed to start config module: %w", err) + } + err = instance.metrics.Start() + if err != nil { + return fmt.Errorf("failed to start metrics module: %w", err) + } + err = instance.rng.Start() + if err != nil { + return fmt.Errorf("failed to start rng module: %w", err) + } + err = instance.base.Start() + if err != nil { + return fmt.Errorf("failed to start base module: %w", err) + } + err = instance.terminal.Start() + if err != nil { + return fmt.Errorf("failed to start terminal module: %w", err) + } + err = instance.cabin.Start() + if err != nil { + return fmt.Errorf("failed to start cabin module: %w", err) + } + err = module.Start() + if err != nil { + return fmt.Errorf("failed to start crew module: %w", err) + } + conf.EnablePublicHub(true) - pmtesting.TestMain(m, module) + m.Run() + + return nil +} + +func TestMain(m *testing.M) { + if err := runTest(m); err != nil { + fmt.Printf("%s", err) + os.Exit(1) + } } diff --git a/spn/crew/op_connect_test.go b/spn/crew/op_connect_test.go index 7205ea9af..9a7e24f01 100644 --- a/spn/crew/op_connect_test.go +++ b/spn/crew/op_connect_test.go @@ -44,7 +44,7 @@ func TestConnectOp(t *testing.T) { // Set up connect op. b.GrantPermission(terminal.MayConnect) conf.EnablePublicHub(true) - identity, err := cabin.CreateIdentity(module.Ctx, "test") + identity, err := cabin.CreateIdentity(module.mgr.Ctx(), "test") if err != nil { t.Fatalf("failed to create identity: %s", err) } diff --git a/spn/docks/bandwidth_test.go b/spn/docks/bandwidth_test.go index 1924be694..3599ce9c6 100644 --- a/spn/docks/bandwidth_test.go +++ b/spn/docks/bandwidth_test.go @@ -66,7 +66,7 @@ func TestEffectiveBandwidth(t *testing.T) { //nolint:paralleltest // Run alone. t.Fatal(tErr) } // Start handler. 
- module.StartWorker("op capacity handler", op.handler) + module.mgr.Go("op capacity handler", op.handler) // Wait for result and check error. tErr = <-op.Result() diff --git a/spn/docks/crane_test.go b/spn/docks/crane_test.go index 9e13b5e10..90b3a93ee 100644 --- a/spn/docks/crane_test.go +++ b/spn/docks/crane_test.go @@ -47,7 +47,7 @@ func testCraneWithCounter(t *testing.T, testID string, encrypting bool, loadSize if err != nil { panic(fmt.Sprintf("crane test %s could not create crane1: %s", testID, err)) } - err = crane1.Start(module.Ctx) + err = crane1.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("crane test %s could not start crane1: %s", testID, err)) } @@ -59,7 +59,7 @@ func testCraneWithCounter(t *testing.T, testID string, encrypting bool, loadSize if err != nil { panic(fmt.Sprintf("crane test %s could not create crane2: %s", testID, err)) } - err = crane2.Start(module.Ctx) + err = crane2.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("crane test %s could not start crane2: %s", testID, err)) } @@ -122,7 +122,7 @@ func (t *StreamingTerminal) ID() uint32 { } func (t *StreamingTerminal) Ctx() context.Context { - return module.Ctx + return module.mgr.Ctx() } func (t *StreamingTerminal) Deliver(msg *terminal.Msg) *terminal.Error { @@ -170,7 +170,7 @@ func testCraneWithStreaming(t *testing.T, testID string, encrypting bool, loadSi if err != nil { panic(fmt.Sprintf("crane test %s could not create crane1: %s", testID, err)) } - err = crane1.Start(module.Ctx) + err = crane1.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("crane test %s could not start crane1: %s", testID, err)) } @@ -182,7 +182,7 @@ func testCraneWithStreaming(t *testing.T, testID string, encrypting bool, loadSi if err != nil { panic(fmt.Sprintf("crane test %s could not create crane2: %s", testID, err)) } - err = crane2.Start(module.Ctx) + err = crane2.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("crane test %s could not start crane2: %s", testID, err)) } @@ -257,7 +257,7 @@ func getTestIdentity(t *testing.T) (*cabin.Identity, *hub.Hub) { if testIdentity == nil { var err error - testIdentity, err = cabin.CreateIdentity(module.Ctx, "test") + testIdentity, err = cabin.CreateIdentity(module.mgr.Ctx(), "test") if err != nil { t.Fatalf("failed to create identity: %s", err) } diff --git a/spn/docks/module_test.go b/spn/docks/module_test.go index 0383cc216..80acc96d2 100644 --- a/spn/docks/module_test.go +++ b/spn/docks/module_test.go @@ -1,16 +1,145 @@ package docks import ( + "fmt" + "os" "testing" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/access" + "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" + "github.com/safing/portmaster/spn/terminal" ) -func TestMain(m *testing.M) { +type testInstance struct { + db *dbmodule.DBModule + config *config.Config + metrics *metrics.Metrics + rng *rng.Rng + base *base.Base + access *access.Access + terminal *terminal.TerminalModule + cabin *cabin.Cabin +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) Metrics() *metrics.Metrics { + return stub.metrics +} + +func (stub *testInstance) SPNGroup() *mgr.ExtendedGroup { + return nil +} 
+ +func (stub *testInstance) Stopping() bool { + return false +} +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + ds, err := config.InitializeUnitTestDataroot("test-docks") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + instance := &testInstance{} runningTests = true conf.EnablePublicHub(true) // Make hub config available. access.EnableTestMode() // Register test zone instead of real ones. - pmtesting.TestMain(m, module) + + // Init + instance.db, err = dbmodule.New(instance) + if err != nil { + return fmt.Errorf("failed to create database module: %w\n", err) + } + instance.config, err = config.New(instance) + if err != nil { + return fmt.Errorf("failed to create config module: %w\n", err) + } + instance.metrics, err = metrics.New(instance) + if err != nil { + return fmt.Errorf("failed to create metrics module: %w\n", err) + } + instance.rng, err = rng.New(instance) + if err != nil { + return fmt.Errorf("failed to create rng module: %w\n", err) + } + instance.base, err = base.New(instance) + if err != nil { + return fmt.Errorf("failed to create base module: %w\n", err) + } + instance.access, err = access.New(instance) + if err != nil { + return fmt.Errorf("failed to create access module: %w\n", err) + } + instance.terminal, err = terminal.New(instance) + if err != nil { + return fmt.Errorf("failed to create terminal module: %w\n", err) + } + instance.cabin, err = cabin.New(instance) + if err != nil { + return fmt.Errorf("failed to create cabin module: %w\n", err) + } + module, err = New(instance) + if err != nil { + return fmt.Errorf("failed to create docks module: %w\n", err) + } + + // Start + err = instance.db.Start() + if err != nil { + return fmt.Errorf("failed to start db module: %w\n", err) + } + err = instance.config.Start() + if err != nil { + return fmt.Errorf("failed to start config module: %w\n", err) + } + err = instance.metrics.Start() + if err != nil { + return fmt.Errorf("failed to start metrics module: %w\n", err) + } + err = instance.rng.Start() + if err != nil { + return fmt.Errorf("failed to start rng module: %w\n", err) + } + err = instance.base.Start() + if err != nil { + return fmt.Errorf("failed to start base module: %w\n", err) + } + err = instance.access.Start() + if err != nil { + return fmt.Errorf("failed to start access module: %w\n", err) + } + err = instance.terminal.Start() + if err != nil { + return fmt.Errorf("failed to start terminal module: %w\n", err) + } + err = instance.cabin.Start() + if err != nil { + return fmt.Errorf("failed to start cabin module: %w\n", err) + } + err = module.Start() + if err != nil { + return fmt.Errorf("failed to start docks module: %w\n", err) + } + + m.Run() + return nil +} + +func TestMain(m *testing.M) { + if err := runTest(m); err != nil { + fmt.Printf("%s\n", err) + os.Exit(1) + } } diff --git a/spn/docks/op_capacity_test.go b/spn/docks/op_capacity_test.go index 1aaa14370..1d7e442fe 100644 --- a/spn/docks/op_capacity_test.go +++ b/spn/docks/op_capacity_test.go @@ -60,7 +60,7 @@ func testCapacityOp(t *testing.T, opts *CapacityTestOptions) { b.GrantPermission(terminal.IsCraneController) op, tErr := NewCapacityTestOp(a, opts) if tErr != nil { - t.Fatalf("failed to start op: %s", err) + t.Fatalf("failed to start op: %s", tErr) } // Wait for result and check error. 
diff --git a/spn/docks/op_latency_test.go b/spn/docks/op_latency_test.go index 7a0b4ec74..8dd53ab0c 100644 --- a/spn/docks/op_latency_test.go +++ b/spn/docks/op_latency_test.go @@ -35,7 +35,7 @@ func TestLatencyOp(t *testing.T) { b.GrantPermission(terminal.IsCraneController) op, tErr := NewLatencyTestOp(a) if tErr != nil { - t.Fatalf("failed to start op: %s", err) + t.Fatalf("failed to start op: %s", tErr) } // Wait for result and check error. diff --git a/spn/docks/terminal_expansion_test.go b/spn/docks/terminal_expansion_test.go index 415716ea3..7dfb38b81 100644 --- a/spn/docks/terminal_expansion_test.go +++ b/spn/docks/terminal_expansion_test.go @@ -102,7 +102,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane1: %s", testID, err)) } crane1.ID = "c1" - err = crane1.Start(module.Ctx) + err = crane1.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane1: %s", testID, err)) } @@ -116,7 +116,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane2to1: %s", testID, err)) } crane2to1.ID = "c2to1" - err = crane2to1.Start(module.Ctx) + err = crane2to1.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane2to1: %s", testID, err)) } @@ -130,7 +130,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane2to3: %s", testID, err)) } crane2to3.ID = "c2to3" - err = crane2to3.Start(module.Ctx) + err = crane2to3.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane2to3: %s", testID, err)) } @@ -144,7 +144,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane3to2: %s", testID, err)) } crane3to2.ID = "c3to2" - err = crane3to2.Start(module.Ctx) + err = crane3to2.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane3to2: %s", testID, err)) } @@ -158,7 +158,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane3to4: %s", testID, err)) } crane3to4.ID = "c3to4" - err = crane3to4.Start(module.Ctx) + err = crane3to4.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane3to4: %s", testID, err)) } @@ -172,7 +172,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane4: %s", testID, err)) } crane4.ID = "c4" - err = crane4.Start(module.Ctx) + err = crane4.Start(module.mgr.Ctx()) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane4: %s", testID, err)) } diff --git a/spn/hub/hub_test.go b/spn/hub/hub_test.go index d715653e8..6f0cd60bb 100644 --- a/spn/hub/hub_test.go +++ b/spn/hub/hub_test.go @@ -1,19 +1,114 @@ package hub import ( + "fmt" "net" + "os" "testing" "github.com/stretchr/testify/assert" - _ "github.com/safing/portmaster/service/core/base" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/updates" ) +type testInstance struct { + db *dbmodule.DBModule + api *api.API + config *config.Config + updates *updates.Updates + base *base.Base +} + +func (stub *testInstance) Updates() *updates.Updates { + return 
stub.updates +} + +func (stub *testInstance) API() *api.API { + return stub.api +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) Base() *base.Base { + return stub.base +} + +func (stub *testInstance) Ready() bool { + return true +} + +func (stub *testInstance) Shutdown(exitCode int) {} + +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + api.SetDefaultAPIListenAddress("0.0.0.0:8080") + ds, err := config.InitializeUnitTestDataroot("test-hub") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + stub := &testInstance{} + // Init + stub.db, err = dbmodule.New(stub) + if err != nil { + return fmt.Errorf("failed to create database: %w", err) + } + stub.api, err = api.New(stub) + if err != nil { + return fmt.Errorf("failed to create api: %w", err) + } + stub.config, err = config.New(stub) + if err != nil { + return fmt.Errorf("failed to create config: %w", err) + } + stub.updates, err = updates.New(stub) + if err != nil { + return fmt.Errorf("failed to create updates: %w", err) + } + stub.base, err = base.New(stub) + if err != nil { + return fmt.Errorf("failed to base updates: %w", err) + } + + // Start + err = stub.db.Start() + if err != nil { + return fmt.Errorf("failed to start database: %w", err) + } + err = stub.api.Start() + if err != nil { + return fmt.Errorf("failed to start api: %w", err) + } + err = stub.config.Start() + if err != nil { + return fmt.Errorf("failed to start config: %w", err) + } + err = stub.updates.Start() + if err != nil { + return fmt.Errorf("failed to start updates: %w", err) + } + err = stub.base.Start() + if err != nil { + return fmt.Errorf("failed to start base: %w", err) + } + + m.Run() + return nil +} + func TestMain(m *testing.M) { - // TODO: We need the database module, so maybe set up a module for this package. 
- module := modules.Register("hub", nil, nil, nil, "base") - pmtesting.TestMain(m, module) + if err := runTest(m); err != nil { + fmt.Printf("%s", err) + os.Exit(1) + } } func TestEquality(t *testing.T) { diff --git a/spn/navigator/module_test.go b/spn/navigator/module_test.go index 4433835fb..31ffdf4af 100644 --- a/spn/navigator/module_test.go +++ b/spn/navigator/module_test.go @@ -1,13 +1,129 @@ package navigator import ( + "fmt" + "os" "testing" + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" "github.com/safing/portmaster/base/log" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/intel/geoip" + "github.com/safing/portmaster/service/updates" ) -func TestMain(m *testing.M) { +type testInstance struct { + db *dbmodule.DBModule + api *api.API + config *config.Config + updates *updates.Updates + base *base.Base + geoip *geoip.GeoIP +} + +func (stub *testInstance) Updates() *updates.Updates { + return stub.updates +} + +func (stub *testInstance) API() *api.API { + return stub.api +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) Base() *base.Base { + return stub.base +} + +func (stub *testInstance) Ready() bool { + return true +} + +func (stub *testInstance) Shutdown(exitCode int) {} + +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + api.SetDefaultAPIListenAddress("0.0.0.0:8080") + ds, err := config.InitializeUnitTestDataroot("test-navigator") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + stub := &testInstance{} log.SetLogLevel(log.DebugLevel) - pmtesting.TestMain(m, module) + + // Init + stub.db, err = dbmodule.New(stub) + if err != nil { + return fmt.Errorf("failed to create db: %w", err) + } + stub.api, err = api.New(stub) + if err != nil { + return fmt.Errorf("failed to create api: %w", err) + } + stub.config, err = config.New(stub) + if err != nil { + return fmt.Errorf("failed to create config: %w", err) + } + stub.updates, err = updates.New(stub) + if err != nil { + return fmt.Errorf("failed to create updates: %w", err) + } + stub.base, err = base.New(stub) + if err != nil { + return fmt.Errorf("failed to create base: %w", err) + } + stub.geoip, err = geoip.New(stub) + if err != nil { + return fmt.Errorf("failed to create geoip: %w", err) + } + module, err = New(stub) + if err != nil { + return fmt.Errorf("failed to create navigator module: %w", err) + } + // Start + err = stub.db.Start() + if err != nil { + return fmt.Errorf("failed to start db module: %w", err) + } + err = stub.api.Start() + if err != nil { + return fmt.Errorf("failed to start api: %w", err) + } + err = stub.config.Start() + if err != nil { + return fmt.Errorf("failed to start config: %w", err) + } + err = stub.updates.Start() + if err != nil { + return fmt.Errorf("failed to start updates: %w", err) + } + err = stub.base.Start() + if err != nil { + return fmt.Errorf("failed to start base module: %w", err) + } + err = stub.geoip.Start() + if err != nil { + return fmt.Errorf("failed to start geoip module: %w", err) + } + err = module.Start() + if err != nil { + return fmt.Errorf("failed to start navigator module: %w", err) + } + + m.Run() + return nil +} + +func TestMain(m *testing.M) { + if err := runTest(m); err != nil { + 
fmt.Printf("%s\n", err) + os.Exit(1) + } } diff --git a/spn/ships/http_shared_test.go b/spn/ships/http_shared_test.go index d48417e4c..190982643 100644 --- a/spn/ships/http_shared_test.go +++ b/spn/ships/http_shared_test.go @@ -8,10 +8,15 @@ import ( ) func TestSharedHTTP(t *testing.T) { //nolint:paralleltest // Test checks global state. + _, err := New(struct{}{}) + if err != nil { + t.Errorf("failed to create module ships: %s", err) + } + const testPort = 65100 // Register multiple handlers. - err := addHTTPHandler(testPort, "", ServeInfoPage) + err = addHTTPHandler(testPort, "", ServeInfoPage) require.NoError(t, err, "should be able to share http listener") err = addHTTPHandler(testPort, "/test", ServeInfoPage) require.NoError(t, err, "should be able to share http listener") diff --git a/spn/terminal/module_test.go b/spn/terminal/module_test.go index 1f07003d4..93ebeafa9 100644 --- a/spn/terminal/module_test.go +++ b/spn/terminal/module_test.go @@ -1,13 +1,122 @@ package terminal import ( + "fmt" + "os" "testing" - "github.com/safing/portmaster/service/core/pmtesting" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/conf" ) +type testInstance struct { + db *dbmodule.DBModule + config *config.Config + metrics *metrics.Metrics + rng *rng.Rng + base *base.Base + cabin *cabin.Cabin +} + +func (stub *testInstance) Config() *config.Config { + return stub.config +} + +func (stub *testInstance) Metrics() *metrics.Metrics { + return stub.metrics +} + +func (stub *testInstance) SPNGroup() *mgr.ExtendedGroup { + return nil +} + +func (stub *testInstance) Stopping() bool { + return false +} +func (stub *testInstance) SetCmdLineOperation(f func() error) {} + +func runTest(m *testing.M) error { + ds, err := config.InitializeUnitTestDataroot("test-terminal") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + conf.EnablePublicHub(true) // Make hub config available. 
+ + instance := &testInstance{} + instance.db, err = dbmodule.New(instance) + if err != nil { + return fmt.Errorf("failed to create database module: %w\n", err) + } + instance.config, err = config.New(instance) + if err != nil { + return fmt.Errorf("failed to create config module: %w\n", err) + } + instance.metrics, err = metrics.New(instance) + if err != nil { + return fmt.Errorf("failed to create metrics module: %w\n", err) + } + instance.rng, err = rng.New(instance) + if err != nil { + return fmt.Errorf("failed to create rng module: %w\n", err) + } + instance.base, err = base.New(instance) + if err != nil { + return fmt.Errorf("failed to create base module: %w\n", err) + } + instance.cabin, err = cabin.New(instance) + if err != nil { + return fmt.Errorf("failed to create cabin module: %w\n", err) + } + _, err = New(instance) + if err != nil { + fmt.Printf("failed to create module: %s\n", err) + os.Exit(0) + } + + // Start + err = instance.db.Start() + if err != nil { + return fmt.Errorf("failed to start db module: %w\n", err) + } + err = instance.config.Start() + if err != nil { + return fmt.Errorf("failed to start config module: %w\n", err) + } + err = instance.metrics.Start() + if err != nil { + return fmt.Errorf("failed to start metrics module: %w\n", err) + } + err = instance.rng.Start() + if err != nil { + return fmt.Errorf("failed to start rng module: %w\n", err) + } + err = instance.base.Start() + if err != nil { + return fmt.Errorf("failed to start base module: %w\n", err) + } + err = instance.cabin.Start() + if err != nil { + return fmt.Errorf("failed to start cabin module: %w\n", err) + } + err = module.Start() + if err != nil { + return fmt.Errorf("failed to start docks module: %w\n", err) + } + + m.Run() + return nil +} + func TestMain(m *testing.M) { - conf.EnablePublicHub(true) - pmtesting.TestMain(m, module) + if err := runTest(m); err != nil { + os.Exit(1) + } } diff --git a/spn/terminal/terminal_test.go b/spn/terminal/terminal_test.go index 1aeee14b3..3f08794e3 100644 --- a/spn/terminal/terminal_test.go +++ b/spn/terminal/terminal_test.go @@ -16,7 +16,7 @@ import ( func TestTerminals(t *testing.T) { t.Parallel() - identity, erro := cabin.CreateIdentity(module.Ctx, "test") + identity, erro := cabin.CreateIdentity(module.mgr.Ctx(), "test") if erro != nil { t.Fatalf("failed to create identity: %s", erro) } @@ -65,7 +65,7 @@ func testTerminals(t *testing.T, identity *cabin.Identity, terminalOpts *Termina var initData *container.Container var err *Error term1, initData, err = NewLocalTestTerminal( - module.Ctx, 127, "c1", dstHub, terminalOpts, createForwardingUpstream( + module.mgr.Ctx(), 127, "c1", dstHub, terminalOpts, createForwardingUpstream( t, "c1", "c2", func(msg *Msg) *Error { return term2.Deliver(msg) }, @@ -75,7 +75,7 @@ func testTerminals(t *testing.T, identity *cabin.Identity, terminalOpts *Termina t.Fatalf("failed to create local terminal: %s", err) } term2, _, err = NewRemoteTestTerminal( - module.Ctx, 127, "c2", identity, initData, createForwardingUpstream( + module.mgr.Ctx(), 127, "c2", identity, initData, createForwardingUpstream( t, "c2", "c1", func(msg *Msg) *Error { return term1.Deliver(msg) }, diff --git a/spn/terminal/testing.go b/spn/terminal/testing.go index 67c7dd8d9..d0a39aa4a 100644 --- a/spn/terminal/testing.go +++ b/spn/terminal/testing.go @@ -35,7 +35,7 @@ func NewLocalTestTerminal( if err != nil { return nil, nil, err } - // t.StartWorkers(module, "test terminal") + t.StartWorkers(module.mgr, "test terminal") return &TestTerminal{t}, initData, 
nil } @@ -54,7 +54,7 @@ func NewRemoteTestTerminal( if err != nil { return nil, nil, err } - // t.StartWorkers(module, "test terminal") + t.StartWorkers(module.mgr, "test terminal") return &TestTerminal{t}, initMsg, nil } @@ -136,38 +136,38 @@ func (t *TestTerminal) HandleAbandon(err *Error) (errorToSend *Error) { return } -// NewSimpleTestTerminalPair provides a simple conntected terminal pair for tests. +// NewSimpleTestTerminalPair provides a simple connected terminal pair for tests. func NewSimpleTestTerminalPair(delay time.Duration, delayQueueSize int, opts *TerminalOpts) (a, b *TestTerminal, err error) { - // if opts == nil { - // opts = &TerminalOpts{ - // Padding: defaultTestPadding, - // FlowControl: FlowControlDFQ, - // FlowControlSize: defaultTestQueueSize, - // } - // } - - // var initData *container.Container - // var tErr *Error - // a, initData, tErr = NewLocalTestTerminal( - // module.Ctx, 127, "a", nil, opts, UpstreamSendFunc(createDelayingTestForwardingFunc( - // "a", "b", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error { - // return b.Deliver(msg) - // }, - // )), - // ) - // if tErr != nil { - // return nil, nil, tErr.Wrap("failed to create local test terminal") - // } - // b, _, tErr = NewRemoteTestTerminal( - // module.Ctx, 127, "b", nil, initData, UpstreamSendFunc(createDelayingTestForwardingFunc( - // "b", "a", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error { - // return a.Deliver(msg) - // }, - // )), - // ) - // if tErr != nil { - // return nil, nil, tErr.Wrap("failed to create remote test terminal") - // } + if opts == nil { + opts = &TerminalOpts{ + Padding: defaultTestPadding, + FlowControl: FlowControlDFQ, + FlowControlSize: defaultTestQueueSize, + } + } + + var initData *container.Container + var tErr *Error + a, initData, tErr = NewLocalTestTerminal( + module.mgr.Ctx(), 127, "a", nil, opts, UpstreamSendFunc(createDelayingTestForwardingFunc( + "a", "b", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error { + return b.Deliver(msg) + }, + )), + ) + if tErr != nil { + return nil, nil, tErr.Wrap("failed to create local test terminal") + } + b, _, tErr = NewRemoteTestTerminal( + module.mgr.Ctx(), 127, "b", nil, initData, UpstreamSendFunc(createDelayingTestForwardingFunc( + "b", "a", delay, delayQueueSize, func(msg *Msg, timeout time.Duration) *Error { + return a.Deliver(msg) + }, + )), + ) + if tErr != nil { + return nil, nil, tErr.Wrap("failed to create remote test terminal") + } return a, b, nil } diff --git a/spn/unit/scheduler_test.go b/spn/unit/scheduler_test.go index 3e3ec6ba0..f2203015e 100644 --- a/spn/unit/scheduler_test.go +++ b/spn/unit/scheduler_test.go @@ -1,8 +1,9 @@ package unit import ( - "context" "testing" + + "github.com/safing/portmaster/service/mgr" ) func BenchmarkScheduler(b *testing.B) { @@ -10,21 +11,22 @@ func BenchmarkScheduler(b *testing.B) { // Create and start scheduler. s := NewScheduler(&SchedulerConfig{}) - ctx, cancel := context.WithCancel(context.Background()) - go func() { + m := mgr.New("unit-test") + m.Go("test", func(ctx *mgr.WorkerCtx) error { err := s.SlotScheduler(ctx) if err != nil { panic(err) } - }() - defer cancel() + return nil + }) + defer m.Cancel() // Init control structures. done := make(chan struct{}) finishedCh := make(chan struct{}) // Start workers. - for i := 0; i < workers; i++ { + for range workers { go func() { for { u := s.NewUnit() @@ -41,7 +43,7 @@ func BenchmarkScheduler(b *testing.B) { // Start benchmark. 
b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { <-finishedCh } b.StopTimer() diff --git a/spn/unit/unit_test.go b/spn/unit/unit_test.go index 8f5a5ac8e..a77ea899e 100644 --- a/spn/unit/unit_test.go +++ b/spn/unit/unit_test.go @@ -1,7 +1,6 @@ package unit import ( - "context" "fmt" "math" "math/rand" @@ -9,6 +8,7 @@ import ( "testing" "time" + "github.com/safing/portmaster/service/mgr" "github.com/stretchr/testify/assert" ) @@ -20,17 +20,19 @@ func TestUnit(t *testing.T) { //nolint:paralleltest size := 1000000 workers := 100 + m := mgr.New("unit-test") // Create and start scheduler. s := NewScheduler(&SchedulerConfig{}) s.StartDebugLog() - ctx, cancel := context.WithCancel(context.Background()) - go func() { - err := s.SlotScheduler(ctx) + // ctx, cancel := context.WithCancel(context.Background()) + m.Go("test", func(w *mgr.WorkerCtx) error { + err := s.SlotScheduler(w) if err != nil { panic(err) } - }() - defer cancel() + return nil + }) + defer m.Cancel() // Create 10 workers. var wg sync.WaitGroup @@ -96,7 +98,7 @@ func TestUnit(t *testing.T) { //nolint:paralleltest ) // Shutdown - cancel() + m.Cancel() time.Sleep(s.config.SlotDuration * 10) // Check if scheduler shut down correctly. From ad64e0443f979582b15c49754c5e7068c4591ba8 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 26 Jul 2024 14:52:50 +0200 Subject: [PATCH 38/56] Review new module system and fix minor issues --- base/api/authentication.go | 2 +- base/api/main.go | 2 - base/api/module.go | 5 +- base/api/modules.go | 47 ------------------- base/runtime/modules_integration.go | 1 + base/template/module.go | 1 + base/utils/debug/debug.go | 15 ------ cmds/observation-hub/observe.go | 1 + service/compat/module.go | 13 +++-- service/core/api.go | 3 +- service/core/core.go | 10 ---- service/firewall/interception/module.go | 2 - .../firewall/interception/nfqueue_linux.go | 5 -- service/instance.go | 5 +- service/intel/customlists/lists.go | 4 +- service/intel/filterlists/module.go | 2 - service/intel/geoip/database.go | 2 +- service/intel/module.go | 12 ----- service/mgr/states.go | 7 +++ service/nameserver/module.go | 12 ----- service/netquery/module_api.go | 24 ---------- service/network/netutils/ip.go | 4 ++ service/profile/migrations.go | 4 +- spn/access/module.go | 2 - spn/captain/module.go | 21 --------- spn/terminal/testing.go | 4 ++ 26 files changed, 40 insertions(+), 170 deletions(-) delete mode 100644 base/api/modules.go delete mode 100644 service/intel/module.go diff --git a/base/api/authentication.go b/base/api/authentication.go index 64ae3538a..73e659f53 100644 --- a/base/api/authentication.go +++ b/base/api/authentication.go @@ -122,7 +122,7 @@ type AuthenticatedHandler interface { // SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted. 
func SetAuthenticator(fn AuthenticatorFunc) error { - if module.online { + if module.online.Load() { return ErrAuthenticationImmutable } diff --git a/base/api/main.go b/base/api/main.go index 06875a399..293997716 100644 --- a/base/api/main.go +++ b/base/api/main.go @@ -19,8 +19,6 @@ var ( ) func init() { - // module = modules.Register("api", prep, start, stop, "database", "config") - flag.BoolVar(&exportEndpoints, "export-api-endpoints", false, "export api endpoint registry and exit") } diff --git a/base/api/module.go b/base/api/module.go index 65f01b629..7502efa9b 100644 --- a/base/api/module.go +++ b/base/api/module.go @@ -13,7 +13,7 @@ type API struct { mgr *mgr.Manager instance instance - online bool + online atomic.Bool } func (api *API) Manager() *mgr.Manager { @@ -26,12 +26,13 @@ func (api *API) Start() error { return err } - module.online = true + api.online.Store(true) return nil } // Stop stops the module. func (api *API) Stop() error { + defer api.online.Store(false) return stop() } diff --git a/base/api/modules.go b/base/api/modules.go deleted file mode 100644 index aa9190e86..000000000 --- a/base/api/modules.go +++ /dev/null @@ -1,47 +0,0 @@ -package api - -import ( - "time" -) - -// ModuleHandler specifies the interface for API endpoints that are bound to a module. -// type ModuleHandler interface { -// BelongsTo() *modules.Module -// } - -const ( - moduleCheckMaxWait = 10 * time.Second - moduleCheckTickDuration = 500 * time.Millisecond -) - -// moduleIsReady checks if the given module is online and http requests can be -// sent its way. If the module is not online already, it will wait for a short -// duration for it to come online. -// func moduleIsReady(m *modules.Module) (ok bool) { -// // Check if we are given a module. -// if m == nil { -// // If no module is given, we assume that the handler has not been linked to -// // a module, and we can safely continue with the request. -// return true -// } - -// // Check if the module is online. -// if m.Online() { -// return true -// } - -// // Check if the module will come online. -// if m.OnlineSoon() { -// var i time.Duration -// for i = 0; i < moduleCheckMaxWait; i += moduleCheckTickDuration { -// // Wait a little. -// time.Sleep(moduleCheckTickDuration) -// // Check if module is now online. -// if m.Online() { -// return true -// } -// } -// } - -// return false -// } diff --git a/base/runtime/modules_integration.go b/base/runtime/modules_integration.go index de85d0b9d..fe4def047 100644 --- a/base/runtime/modules_integration.go +++ b/base/runtime/modules_integration.go @@ -16,6 +16,7 @@ func startModulesIntegration() (err error) { return err } + // FIXME(Daniel): What did this do? Do we need it? // if !modules.SetEventSubscriptionFunc(pushModuleEvent) { // log.Warningf("runtime: failed to register the modules event subscription function") // } diff --git a/base/template/module.go b/base/template/module.go index bbf3b71ad..2aa06e68b 100644 --- a/base/template/module.go +++ b/base/template/module.go @@ -1,5 +1,6 @@ package template +// FIXME: update template to new best way to do it. // import ( // "context" // "time" diff --git a/base/utils/debug/debug.go b/base/utils/debug/debug.go index 0446f9d90..ed8afbc3b 100644 --- a/base/utils/debug/debug.go +++ b/base/utils/debug/debug.go @@ -110,21 +110,6 @@ func (di *Info) AddGoroutineStack() { ) } -// AddLastReportedModuleError adds the last reported module error, if one exists. 
-func (di *Info) AddLastReportedModuleError() { - // me := modules.GetLastReportedError() - // if me == nil { - // di.AddSection("No Module Error", NoFlags) - // return - // } - - // di.AddSection( - // fmt.Sprintf("%s Module Error", strings.Title(me.ModuleName)), //nolint:staticcheck - // UseCodeSection, - // me.Format(), - // ) -} - // AddLastUnexpectedLogs adds the last 10 unexpected log lines, if any. func (di *Info) AddLastUnexpectedLogs() { lines := log.GetLastUnexpectedLogs() diff --git a/cmds/observation-hub/observe.go b/cmds/observation-hub/observe.go index cec4c687c..47955c2af 100644 --- a/cmds/observation-hub/observe.go +++ b/cmds/observation-hub/observe.go @@ -12,6 +12,7 @@ import ( diff "github.com/r3labs/diff/v3" "golang.org/x/exp/slices" + "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/log" diff --git a/service/compat/module.go b/service/compat/module.go index 5726b355d..cc40e29aa 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -17,8 +17,10 @@ type Compat struct { mgr *mgr.Manager instance instance - selfcheckWorkerMgr *mgr.WorkerMgr - states *mgr.StateMgr + selfcheckWorkerMgr *mgr.WorkerMgr + cleanNotifyThresholdWorkerMgr *mgr.WorkerMgr + + states *mgr.StateMgr } func (u *Compat) Manager() *mgr.Manager { @@ -76,9 +78,9 @@ func start() error { startNotify() selfcheckNetworkChangedFlag.Refresh() - module.selfcheckWorkerMgr = module.mgr.Repeat("compatibility self-check", 5*time.Minute, selfcheckTaskFunc).Delay(selfcheckTaskRetryAfter) + module.selfcheckWorkerMgr.Repeat(5 * time.Minute).Delay(selfcheckTaskRetryAfter) + module.cleanNotifyThresholdWorkerMgr.Repeat(1 * time.Hour) - _ = module.mgr.Repeat("clean notify thresholds", 1*time.Hour, cleanNotifyThreshold) module.instance.NetEnv().EventNetworkChange.AddCallback("trigger compat self-check", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { module.selfcheckWorkerMgr.Delay(selfcheckTaskRetryAfter) return false, nil @@ -163,6 +165,9 @@ func New(instance instance) (*Compat, error) { mgr: m, instance: instance, + selfcheckWorkerMgr: m.NewWorkerMgr("compatibility self-check", selfcheckTaskFunc, nil), + cleanNotifyThresholdWorkerMgr: m.NewWorkerMgr("clean notify thresholds", cleanNotifyThreshold, nil), + states: mgr.NewStateMgr(m), } if err := prep(); err != nil { diff --git a/service/core/api.go b/service/core/api.go index ea03ac420..5893c7066 100644 --- a/service/core/api.go +++ b/service/core/api.go @@ -139,8 +139,7 @@ func debugInfo(ar *api.Request) (data []byte, err error) { di.AddVersionInfo() di.AddPlatformInfo(ar.Context()) - // Errors and unexpected logs. - di.AddLastReportedModuleError() + // Unexpected logs. di.AddLastUnexpectedLogs() // Status Information from various modules. 
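Note on the worker scheduling pattern introduced in the compat changes above (a minimal sketch, not the actual Portmaster code; the type "Example" is hypothetical, while the mgr calls mirror those visible in this patch, e.g. NewWorkerMgr and Repeat(...).Delay(...)): workers are created once in the module constructor and only scheduled in Start, instead of being registered ad hoc at start time.

	// Sketch only, assuming the mgr API shape shown in the diff above.
	type Example struct {
		mgr                *mgr.Manager
		selfcheckWorkerMgr *mgr.WorkerMgr
	}

	func NewExample(m *mgr.Manager) *Example {
		return &Example{
			mgr: m,
			// The worker is created here, but not yet scheduled.
			selfcheckWorkerMgr: m.NewWorkerMgr("compatibility self-check", selfcheckTaskFunc, nil),
		}
	}

	func (e *Example) Start() error {
		// Scheduling happens in Start: run every 5 minutes and
		// delay the next run when a retry is requested.
		e.selfcheckWorkerMgr.Repeat(5 * time.Minute).Delay(selfcheckTaskRetryAfter)
		return nil
	}
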
diff --git a/service/core/core.go b/service/core/core.go index 983e9eee2..06c4c5630 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -45,16 +45,6 @@ func (c *Core) Stop() error { var disableShutdownEvent bool func init() { - // module = modules.Register("core", prep, start, nil, "base", "subsystems", "status", "updates", "api", "notifications", "ui", "netenv", "network", "netquery", "interception", "compat", "broadcasts", "sync") - // subsystems.Register( - // "core", - // "Core", - // "Base Structure and System Integration", - // module, - // "config:core/", - // nil, - // ) - flag.BoolVar( &disableShutdownEvent, "disable-shutdown-event", diff --git a/service/firewall/interception/module.go b/service/firewall/interception/module.go index da0a727fd..72a66abe4 100644 --- a/service/firewall/interception/module.go +++ b/service/firewall/interception/module.go @@ -39,8 +39,6 @@ var ( func init() { flag.BoolVar(&disableInterception, "disable-interception", false, "disable packet interception; this breaks a lot of functionality") - - // module = modules.Register("interception", prep, start, stop, "base", "updates", "network", "notifications", "profiles") } // Start starts the interception. diff --git a/service/firewall/interception/nfqueue_linux.go b/service/firewall/interception/nfqueue_linux.go index bff94fe28..2f83480fe 100644 --- a/service/firewall/interception/nfqueue_linux.go +++ b/service/firewall/interception/nfqueue_linux.go @@ -258,30 +258,25 @@ func StartNfqueueInterception(packets chan<- packet.Packet) (err error) { err = activateNfqueueFirewall() if err != nil { - // _ = StopNfqueueInterception() return fmt.Errorf("could not initialize nfqueue: %w", err) } out4Queue, err = nfq.New(17040, false) if err != nil { - // _ = StopNfqueueInterception() return fmt.Errorf("nfqueue(IPv4, out): %w", err) } in4Queue, err = nfq.New(17140, false) if err != nil { - // _ = StopNfqueueInterception() return fmt.Errorf("nfqueue(IPv4, in): %w", err) } if netenv.IPv6Enabled() { out6Queue, err = nfq.New(17060, true) if err != nil { - // _ = StopNfqueueInterception() return fmt.Errorf("nfqueue(IPv6, out): %w", err) } in6Queue, err = nfq.New(17160, true) if err != nil { - // _ = StopNfqueueInterception() return fmt.Errorf("nfqueue(IPv6, in): %w", err) } } else { diff --git a/service/instance.go b/service/instance.go index d53cbaec2..14ac54981 100644 --- a/service/instance.go +++ b/service/instance.go @@ -46,7 +46,7 @@ import ( "github.com/safing/portmaster/spn/terminal" ) -// Instance is an instance of a portmaste service. +// Instance is an instance of a Portmaster service. type Instance struct { ctx context.Context cancelCtx context.CancelFunc @@ -100,7 +100,7 @@ type Instance struct { CommandLineOperation func() error } -// New returns a new portmaster service instance. +// New returns a new Portmaster service instance. func New(svcCfg *ServiceConfig) (*Instance, error) { // Create instance to pass it to modules. instance := &Instance{} @@ -313,6 +313,7 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { } func (i *Instance) SetSleep(enabled bool) { + // FIXME(Daniel): Use a loop and a interface check to set sleep on all supported modules. 
i.metrics.SetSleep(enabled) i.network.SetSleep(enabled) i.captain.SetSleep(enabled) diff --git a/service/intel/customlists/lists.go b/service/intel/customlists/lists.go index 797d0723c..95fb661c4 100644 --- a/service/intel/customlists/lists.go +++ b/service/intel/customlists/lists.go @@ -87,6 +87,8 @@ func parseFile(filePath string) error { Type: mgr.StateTypeWarning, }) return err + } else { + module.states.Remove(parseWarningNotificationID) } defer func() { _ = file.Close() }() @@ -140,8 +142,6 @@ func parseFile(filePath string) error { len(autonomousSystemsFilterList), len(countryCodesFilterList))) - module.states.Remove(parseWarningNotificationID) - return nil } diff --git a/service/intel/filterlists/module.go b/service/intel/filterlists/module.go index c529c207b..92f6576e1 100644 --- a/service/intel/filterlists/module.go +++ b/service/intel/filterlists/module.go @@ -54,8 +54,6 @@ var ( func init() { ignoreNetEnvEvents.Set() - - // module = modules.Register("filterlists", prep, start, stop, "base", "updates") } func prep() error { diff --git a/service/intel/geoip/database.go b/service/intel/geoip/database.go index 6aee3d944..5f0258a73 100644 --- a/service/intel/geoip/database.go +++ b/service/intel/geoip/database.go @@ -148,7 +148,7 @@ func (upd *updateWorker) triggerUpdate() { func (upd *updateWorker) start() { upd.once.Do(func() { - module.mgr.Delay("geoip-updater", time.Second*10, upd.run) + module.mgr.Go("geoip-updater", upd.run) }) } diff --git a/service/intel/module.go b/service/intel/module.go deleted file mode 100644 index 41a7386aa..000000000 --- a/service/intel/module.go +++ /dev/null @@ -1,12 +0,0 @@ -package intel - -// import ( -// _ "github.com/safing/portmaster/service/intel/customlists" -// ) - -// Module of this package. Export needed for testing of the endpoints package. -// var Module *modules.Module - -// func init() { -// Module = modules.Register("intel", nil, nil, nil, "geoip", "filterlists", "customlists") -// } diff --git a/service/mgr/states.go b/service/mgr/states.go index a4228522d..9049bb298 100644 --- a/service/mgr/states.go +++ b/service/mgr/states.go @@ -125,6 +125,13 @@ func (m *StateMgr) Remove(id string) { m.statesLock.Lock() defer m.statesLock.Unlock() + // Quick check if slice is empty. + // It is a common pattern to remove a state when no error was encountered at + // a critical operation. This means that StateMgr.Remove will be called often. + if len(m.states) == 0 { + return + } + var entryRemoved bool m.states = slices.DeleteFunc(m.states, func(s State) bool { if s.ID == id { diff --git a/service/nameserver/module.go b/service/nameserver/module.go index 69fad7f28..b9a855ff5 100644 --- a/service/nameserver/module.go +++ b/service/nameserver/module.go @@ -52,18 +52,6 @@ var ( eventIDListenerFailed = "nameserver:listener-failed" ) -func init() { - // module = modules.Register("nameserver", prep, start, stop, "core", "resolver") - // subsystems.Register( - // "dns", - // "Secure DNS", - // "DNS resolver with scoping and DNS-over-TLS", - // module, - // "config:dns/", - // nil, - // ) -} - func prep() error { return registerConfig() } diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go index e5250144a..9dde73002 100644 --- a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -33,30 +33,6 @@ type NetQuery struct { feed chan *network.Connection } -// DefaultModule is the default netquery module. 
-func init() { - // DefaultModule = new(module) - - // DefaultModule.Module = modules.Register( - // "netquery", - // DefaultModule.prepare, - // DefaultModule.start, - // DefaultModule.stop, - // "api", - // "network", - // "database", - // ) - - // subsystems.Register( - // "history", - // "Network History", - // "Keep Network History Data", - // DefaultModule.Module, - // "config:history/", - // nil, - // ) -} - func (nq *NetQuery) prepare() error { var err error diff --git a/service/network/netutils/ip.go b/service/network/netutils/ip.go index c77cc6b92..af316af21 100644 --- a/service/network/netutils/ip.go +++ b/service/network/netutils/ip.go @@ -84,6 +84,10 @@ func GetIPScope(ip net.IP) IPScope { //nolint:gocognit } } else if len(ip) == net.IPv6len { // IPv6 + + // TODO: Add IPv6 RFC5771 test / doc networks + // 2001:db8::/32 + // 3fff::/20 switch { case ip.Equal(net.IPv6zero): return Invalid diff --git a/service/profile/migrations.go b/service/profile/migrations.go index d081a3fb3..b9c813537 100644 --- a/service/profile/migrations.go +++ b/service/profile/migrations.go @@ -131,7 +131,7 @@ func migrateIcons(ctx context.Context, _, to *version.Version, db *database.Inte // Normally, an icon migration would not be such a big error, but this is a test // run for the profile IDs and we absolutely need to know if anything went wrong. module.states.Add(mgr.State{ - ID: "migration-failed", + ID: "migration-failed-icons", Name: "Profile Migration Failed", Message: fmt.Sprintf("Failed to migrate icons of %d profiles (out of %d pending). The last error was: %s\n\nPlease restart Portmaster to try the migration again.", failed, total, lastErr), Type: mgr.StateTypeError, @@ -220,7 +220,7 @@ func migrateToDerivedIDs(ctx context.Context, _, to *version.Version, db *databa // Log migration failure and try again next time. if lastErr != nil { module.states.Add(mgr.State{ - ID: "migration-failed", + ID: "migration-failed-derived-IDs", Name: "Profile Migration Failed", Message: fmt.Sprintf("Failed to migrate profile IDs of %d profiles (out of %d pending). The last error was: %s\n\nPlease restart Portmaster to try the migration again.", failed, total, lastErr), Type: mgr.StateTypeError, diff --git a/spn/access/module.go b/spn/access/module.go index c407518a2..4a318c793 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -41,8 +41,6 @@ var ( module *Access shimLoaded atomic.Bool - // accountUpdateTask *modules.Task - tokenIssuerIsFailing = abool.New() tokenIssuerRetryDuration = 10 * time.Minute diff --git a/spn/captain/module.go b/spn/captain/module.go index 0c3eb906e..4864f37ab 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -62,27 +62,6 @@ func (c *Captain) SetSleep(enabled bool) { } } -func init() { - // subsystems.Register( - // "spn", - // "SPN", - // "Safing Privacy Network", - // module, - // "config:spn/", - // &config.Option{ - // Name: "SPN Module", - // Key: CfgOptionEnableSPNKey, - // Description: "Start the Safing Privacy Network module. If turned off, the SPN is fully disabled on this device.", - // OptType: config.OptTypeBool, - // DefaultValue: false, - // Annotations: config.Annotations{ - // config.DisplayOrderAnnotation: cfgOptionEnableSPNOrder, - // config.CategoryAnnotation: "General", - // }, - // }, - // ) -} - func prep() error { // Check if we can parse the bootstrap hub flag. 
if err := prepBootstrapHubFlag(); err != nil { diff --git a/spn/terminal/testing.go b/spn/terminal/testing.go index 67c7dd8d9..27823a9b5 100644 --- a/spn/terminal/testing.go +++ b/spn/terminal/testing.go @@ -35,6 +35,7 @@ func NewLocalTestTerminal( if err != nil { return nil, nil, err } + // FIXME: We need this! // t.StartWorkers(module, "test terminal") return &TestTerminal{t}, initData, nil @@ -54,6 +55,7 @@ func NewRemoteTestTerminal( if err != nil { return nil, nil, err } + // FIXME: We need this! // t.StartWorkers(module, "test terminal") return &TestTerminal{t}, initMsg, nil @@ -138,6 +140,8 @@ func (t *TestTerminal) HandleAbandon(err *Error) (errorToSend *Error) { // NewSimpleTestTerminalPair provides a simple conntected terminal pair for tests. func NewSimpleTestTerminalPair(delay time.Duration, delayQueueSize int, opts *TerminalOpts) (a, b *TestTerminal, err error) { + // FIXME: I think we need this? + // if opts == nil { // opts = &TerminalOpts{ // Padding: defaultTestPadding, From 4980f33e413c31bfeac75786ff8f0b051c577678 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 26 Jul 2024 16:33:12 +0200 Subject: [PATCH 39/56] Push shutdown and restart events again via API --- base/api/endpoints_debug.go | 1 - base/runtime/module.go | 5 ---- service/core/core.go | 4 +++ .../core/events.go | 25 +++++++++++++------ service/network/api.go | 3 +-- 5 files changed, 22 insertions(+), 16 deletions(-) rename base/runtime/modules_integration.go => service/core/events.go (65%) diff --git a/base/api/endpoints_debug.go b/base/api/endpoints_debug.go index d2db39601..55865d9c5 100644 --- a/base/api/endpoints_debug.go +++ b/base/api/endpoints_debug.go @@ -241,7 +241,6 @@ func debugInfo(ar *Request) (data []byte, err error) { // Add debug information. di.AddVersionInfo() di.AddPlatformInfo(ar.Context()) - di.AddLastReportedModuleError() di.AddLastUnexpectedLogs() di.AddGoroutineStack() diff --git a/base/runtime/module.go b/base/runtime/module.go index ac46d02aa..21b7399c7 100644 --- a/base/runtime/module.go +++ b/base/runtime/module.go @@ -2,7 +2,6 @@ package runtime import ( "errors" - "fmt" "sync/atomic" "github.com/safing/portmaster/base/database" @@ -37,10 +36,6 @@ func (r *Runtime) Start() error { return err } - if err := startModulesIntegration(); err != nil { - return fmt.Errorf("failed to start modules integration: %w", err) - } - return nil } diff --git a/service/core/core.go b/service/core/core.go index 06c4c5630..9d6ab2305 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -64,6 +64,10 @@ func prep() error { return err } + if err := initModulesIntegration(); err != nil { + return err + } + return nil } diff --git a/base/runtime/modules_integration.go b/service/core/events.go similarity index 65% rename from base/runtime/modules_integration.go rename to service/core/events.go index fe4def047..362024d0e 100644 --- a/base/runtime/modules_integration.go +++ b/service/core/events.go @@ -1,4 +1,4 @@ -package runtime +package core import ( "fmt" @@ -6,20 +6,29 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/record" + "github.com/safing/portmaster/base/runtime" + "github.com/safing/portmaster/service/mgr" ) var modulesIntegrationUpdatePusher func(...record.Record) -func startModulesIntegration() (err error) { - modulesIntegrationUpdatePusher, err = Register("modules/", &ModulesIntegration{}) +func initModulesIntegration() (err error) { + modulesIntegrationUpdatePusher, err = runtime.Register("modules/", &ModulesIntegration{}) if err != 
nil { return err } - // FIXME(Daniel): What did this do? Do we need it? - // if !modules.SetEventSubscriptionFunc(pushModuleEvent) { - // log.Warningf("runtime: failed to register the modules event subscription function") - // } + // Push events via API. + module.EventRestart.AddCallback("expose restart event", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + // Send event as runtime:modules/core/event/restart + pushModuleEvent("core", "restart", false, nil) + return false, nil + }) + module.EventShutdown.AddCallback("expose shutdown event", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { + // Send event as runtime:modules/core/event/shutdown + pushModuleEvent("core", "shutdown", false, nil) + return false, nil + }) return nil } @@ -33,7 +42,7 @@ type ModulesIntegration struct{} // the record passed to Set is prefixed with the key used // to register the value provider. func (mi *ModulesIntegration) Set(record.Record) (record.Record, error) { - return nil, ErrReadOnly + return nil, runtime.ErrReadOnly } // Get should return one or more records that match keyOrPrefix. diff --git a/service/network/api.go b/service/network/api.go index 78f7f7511..82b11ad0e 100644 --- a/service/network/api.go +++ b/service/network/api.go @@ -77,8 +77,7 @@ func debugInfo(ar *api.Request) (data []byte, err error) { di.AddVersionInfo() di.AddPlatformInfo(ar.Context()) - // Errors and unexpected logs. - di.AddLastReportedModuleError() + // Unexpected logs. di.AddLastUnexpectedLogs() // Network Connections. From 259845552ee8d01eddb67f32db761b316275fdd3 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 26 Jul 2024 16:40:45 +0200 Subject: [PATCH 40/56] Set sleep mode via interface --- service/instance.go | 20 ++++++++++++++++---- service/mgr/group.go | 9 +++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/service/instance.go b/service/instance.go index 14ac54981..aa9adb4be 100644 --- a/service/instance.go +++ b/service/instance.go @@ -312,11 +312,23 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { return instance, nil } +// SleepyModule is an interface for modules that can enter some sort of sleep mode. +type SleepyModule interface { + SetSleep(enabled bool) +} + +// SetSleep sets sleep mode on all modules that satisfy the SleepyModule interface. func (i *Instance) SetSleep(enabled bool) { - // FIXME(Daniel): Use a loop and a interface check to set sleep on all supported modules. - i.metrics.SetSleep(enabled) - i.network.SetSleep(enabled) - i.captain.SetSleep(enabled) + for _, module := range i.serviceGroup.Modules() { + if sm, ok := module.(SleepyModule); ok { + sm.SetSleep(enabled) + } + } + for _, module := range i.SpnGroup.Modules() { + if sm, ok := module.(SleepyModule); ok { + sm.SetSleep(enabled) + } + } } // Database returns the database module. diff --git a/service/mgr/group.go b/service/mgr/group.go index 51d1fab17..b52798db8 100644 --- a/service/mgr/group.go +++ b/service/mgr/group.go @@ -235,6 +235,15 @@ func (g *Group) AddStatesCallback(callbackName string, callback EventCallbackFun } } +// Modules returns a copy of the modules. +func (g *Group) Modules() []Module { + copied := make([]Module, 0, len(g.modules)) + for _, gm := range g.modules { + copied = append(copied, gm.module) + } + return copied +} + // RunModules is a simple wrapper function to start modules and stop them again // when the given context is canceled. 
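//
// A minimal usage sketch (the context setup and the two module values shown
// here are illustrative placeholders, not part of this change):
//
//	ctx, cancel := context.WithCancel(context.Background())
//	defer cancel()
//	if err := RunModules(ctx, dbModule, apiModule); err != nil {
//		return err
//	}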
func RunModules(ctx context.Context, modules ...Module) error { From a3703c2a9a79c4f20bf7c9032ad04e8c6d70604d Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 26 Jul 2024 16:59:32 +0200 Subject: [PATCH 41/56] Update example/template module --- base/template/module.go | 96 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 base/template/module.go diff --git a/base/template/module.go b/base/template/module.go new file mode 100644 index 000000000..a789eaf06 --- /dev/null +++ b/base/template/module.go @@ -0,0 +1,96 @@ +package template + +import ( + "time" + + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/service/mgr" +) + +// Template showcases the usage of the module system. +type Template struct { + i instance + m *mgr.Manager + states *mgr.StateMgr + + EventRecordAdded *mgr.EventMgr[string] + EventRecordDeleted *mgr.EventMgr[string] + + specialWorkerMgr *mgr.WorkerMgr +} + +type instance interface{} + +// New returns a new template. +func New(instance instance) (*Template, error) { + m := mgr.New("template") + t := &Template{ + i: instance, + m: m, + states: m.NewStateMgr(), + + EventRecordAdded: mgr.NewEventMgr[string]("record added", m), + EventRecordDeleted: mgr.NewEventMgr[string]("record deleted", m), + + specialWorkerMgr: m.NewWorkerMgr("special worker", serviceWorker, nil), + } + + // register options + err := config.Register(&config.Option{ + Name: "language", + Key: "template/language", + Description: "Sets the language for the template [TEMPLATE]", + OptType: config.OptTypeString, + ExpertiseLevel: config.ExpertiseLevelUser, // default + ReleaseLevel: config.ReleaseLevelStable, // default + RequiresRestart: false, // default + DefaultValue: "en", + ValidationRegex: "^[a-z]{2}$", + }) + if err != nil { + return nil, err + } + + return t, nil +} + +// Manager returns the module manager. +func (t *Template) Manager() *mgr.Manager { + return t.m +} + +// States returns the module states. +func (t *Template) States() *mgr.StateMgr { + return t.states +} + +// Start starts the module. +func (t *Template) Start() error { + t.m.Go("worker", serviceWorker) + t.specialWorkerMgr.Delay(10 * time.Minute) + + return nil +} + +// Stop stops the module. 
+func (t *Template) Stop() error { + return nil +} + +func serviceWorker(w *mgr.WorkerCtx) error { + for { + select { + case <-time.After(1 * time.Second): + err := do() + if err != nil { + return err + } + case <-w.Done(): + return nil + } + } +} + +func do() error { + return nil +} From ecd2186fb86a3e6184987c1c1ccd612fa0212018 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Tue, 30 Jul 2024 10:41:21 +0300 Subject: [PATCH 42/56] [WIP] Fix spn/cabin unit test --- cmds/observation-hub/observe.go | 1 - spn/cabin/module_test.go | 55 ++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/cmds/observation-hub/observe.go b/cmds/observation-hub/observe.go index 2e76ddd6a..76f0c0a96 100644 --- a/cmds/observation-hub/observe.go +++ b/cmds/observation-hub/observe.go @@ -12,7 +12,6 @@ import ( diff "github.com/r3labs/diff/v3" "golang.org/x/exp/slices" - "github.com/safing/portbase/modules" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/log" diff --git a/spn/cabin/module_test.go b/spn/cabin/module_test.go index 13387a1d6..bf342849d 100644 --- a/spn/cabin/module_test.go +++ b/spn/cabin/module_test.go @@ -5,13 +5,21 @@ import ( "os" "testing" + "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" ) type testInstance struct { + db *dbmodule.DBModule + api *api.API config *config.Config + rng *rng.Rng + base *base.Base } func (stub *testInstance) Config() *config.Config { @@ -25,23 +33,68 @@ func (stub *testInstance) SPNGroup() *mgr.ExtendedGroup { func (stub *testInstance) Stopping() bool { return false } + +func (stub *testInstance) Ready() bool { + return true +} func (stub *testInstance) SetCmdLineOperation(f func() error) {} func runTest(m *testing.M) error { + api.SetDefaultAPIListenAddress("0.0.0.0:8080") + // Initialize dataroot + ds, err := config.InitializeUnitTestDataroot("test-cabin") + if err != nil { + return fmt.Errorf("failed to initialize dataroot: %w", err) + } + defer func() { _ = os.RemoveAll(ds) }() + + // Init instance := &testInstance{} - var err error + instance.db, err = dbmodule.New(instance) + if err != nil { + return fmt.Errorf("failed to create database: %w", err) + } + instance.api, err = api.New(instance) + if err != nil { + return fmt.Errorf("failed to create api: %w", err) + } instance.config, err = config.New(instance) if err != nil { return fmt.Errorf("failed to create config module: %w", err) } + instance.rng, err = rng.New(instance) + if err != nil { + return fmt.Errorf("failed to create rng module: %w", err) + } + instance.base, err = base.New(instance) + if err != nil { + return fmt.Errorf("failed to create base module: %w", err) + } module, err = New(struct{}{}) if err != nil { return fmt.Errorf("failed to create cabin module: %w", err) } + // Start + err = instance.db.Start() + if err != nil { + return fmt.Errorf("failed to start database: %w", err) + } + err = instance.api.Start() + if err != nil { + return fmt.Errorf("failed to start api: %w", err) + } err = instance.config.Start() if err != nil { return fmt.Errorf("failed to start config module: %w", err) } + err = instance.rng.Start() + if err != nil { + return fmt.Errorf("failed to start rng module: %w", err) + } + err =
instance.base.Start() + if err != nil { + return fmt.Errorf("failed to start base module: %w", err) + } err = module.Start() if err != nil { return fmt.Errorf("failed to start cabin module: %w", err) From 018e640363511f925061ae2f4e50179e684a0b08 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 30 Jul 2024 17:05:05 +0200 Subject: [PATCH 43/56] Remove deprecated UI elements --- .../angular/src/app/pages/spn/map.service.ts | 3 --- .../connection-details/conn-details.html | 9 ------- .../netquery/connection-helper.service.ts | 24 ------------------- .../netquery/connection-row/conn-row.html | 2 +- .../netquery/connection-row/conn-row.ts | 15 ------------ 5 files changed, 1 insertion(+), 52 deletions(-) diff --git a/desktop/angular/src/app/pages/spn/map.service.ts b/desktop/angular/src/app/pages/spn/map.service.ts index da8041a99..b8c0a65ae 100644 --- a/desktop/angular/src/app/pages/spn/map.service.ts +++ b/desktop/angular/src/app/pages/spn/map.service.ts @@ -30,9 +30,6 @@ export interface MapPin { // whether the pin has any known issues hasIssues: boolean; - - // FIXME: remove me - collapsed?: boolean; } @Injectable({ providedIn: 'root' }) diff --git a/desktop/angular/src/app/shared/netquery/connection-details/conn-details.html b/desktop/angular/src/app/shared/netquery/connection-details/conn-details.html index 087e206c0..73abd31f8 100644 --- a/desktop/angular/src/app/shared/netquery/connection-details/conn-details.html +++ b/desktop/angular/src/app/shared/netquery/connection-details/conn-details.html @@ -219,15 +219,6 @@ "Global" }} Settings - -
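Stepping back to the spn/cabin test change above: it follows the unit-test bootstrap used across this series, where a testInstance stub satisfies the module's instance interface, dependency modules are constructed in order, and then started in the same order before the tests run. The parts of runTest outside the shown hunk (running the tests, teardown) are not visible here; a condensed, hypothetical sketch of the overall shape, reusing the example module sketched earlier and the config.InitializeUnitTestDataroot helper, might look like this (names and error handling are illustrative):

package example

import (
	"fmt"
	"os"
	"testing"

	"github.com/safing/portmaster/base/config"
)

func TestMain(m *testing.M) {
	if err := runTest(m); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

func runTest(m *testing.M) error {
	// Use a throwaway data root so the test never touches real data.
	ds, err := config.InitializeUnitTestDataroot("test-example")
	if err != nil {
		return fmt.Errorf("failed to initialize dataroot: %w", err)
	}
	defer func() { _ = os.RemoveAll(ds) }()

	// Construct the module under test (in a real test: its dependencies first, in order).
	module, err := New()
	if err != nil {
		return fmt.Errorf("failed to create example module: %w", err)
	}

	// Start in the same order as construction, module under test last.
	if err := module.Start(); err != nil {
		return fmt.Errorf("failed to start example module: %w", err)
	}

	// Run the tests; exit-code handling is omitted for brevity.
	m.Run()
	return nil
}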
diff --git a/desktop/angular/src/app/shared/netquery/connection-helper.service.ts b/desktop/angular/src/app/shared/netquery/connection-helper.service.ts index fbe1b7692..89b8ff0a3 100644 --- a/desktop/angular/src/app/shared/netquery/connection-helper.service.ts +++ b/desktop/angular/src/app/shared/netquery/connection-helper.service.ts @@ -510,28 +510,4 @@ export class NetqueryHelper { } }); } - - /** - * Iterates of all outgoing rules and collects which domains are blocked. - * It stops collecting domains as soon as the first "allow something" rule - * is hit. - */ - // FIXME - /* - private collectBlockedDomains() { - let blockedDomains = new Set(); - - const rules = getAppSetting(this.profile!.profile!.Config, 'filter/endpoints') || []; - for (let i = 0; i < rules.length; i++) { - const rule = rules[i]; - if (rule.startsWith('+ ')) { - break; - } - - blockedDomains.add(rule.slice(2)) - } - - this.blockedDomains = Array.from(blockedDomains) - } - */ } diff --git a/desktop/angular/src/app/shared/netquery/connection-row/conn-row.html b/desktop/angular/src/app/shared/netquery/connection-row/conn-row.html index 3c721b0b6..ff09c532a 100644 --- a/desktop/angular/src/app/shared/netquery/connection-row/conn-row.html +++ b/desktop/angular/src/app/shared/netquery/connection-row/conn-row.html @@ -1,5 +1,5 @@
- diff --git a/desktop/angular/src/app/shared/netquery/connection-row/conn-row.ts b/desktop/angular/src/app/shared/netquery/connection-row/conn-row.ts index b841d116c..d5fee6634 100644 --- a/desktop/angular/src/app/shared/netquery/connection-row/conn-row.ts +++ b/desktop/angular/src/app/shared/netquery/connection-row/conn-row.ts @@ -30,21 +30,6 @@ export class SfngNetqueryConnectionRowComponent implements OnInit, OnDestroy { @Input() activeRevision: number | undefined = 0; - get isOutdated() { - // FIXME(ppacher) - return false; - /* - if (!this.conn || !this.helper.profile) { - return false; - } - if (this.helper.profile.currentProfileRevision === -1) { - // we don't know the revision counter yet ... - return false; - } - return this.conn.profile_revision !== this.helper.profile.currentProfileRevision; - */ - } - /* timeAgoTicker ticks every 10000 seconds to force a refresh of the timeAgo pipes */ timeAgoTicker: number = 0; From 491d9ea4b6cf1da5087b0b3276131822c9f96e1b Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 30 Jul 2024 17:05:50 +0200 Subject: [PATCH 44/56] Make log output more similar for the logging transition phase --- base/log/formatting.go | 87 +++++++++++++++++++++++++++++----- base/log/formatting_unix.go | 21 ++++++-- base/log/formatting_windows.go | 35 ++++++++++++-- base/log/logging.go | 6 +++ base/log/output.go | 24 +++++++++- base/log/slog.go | 65 +++++++++++++++++++++++++ 6 files changed, 215 insertions(+), 23 deletions(-) create mode 100644 base/log/slog.go diff --git a/base/log/formatting.go b/base/log/formatting.go index a9bd519d3..9394f9b1f 100644 --- a/base/log/formatting.go +++ b/base/log/formatting.go @@ -9,48 +9,91 @@ var counter uint16 const ( maxCount uint16 = 999 - timeFormat string = "060102 15:04:05.000" + timeFormat string = "2006-01-02 15:04:05.000" ) func (s Severity) String() string { switch s { case TraceLevel: - return "TRAC" + return "TRC" case DebugLevel: - return "DEBU" + return "DBG" case InfoLevel: - return "INFO" + return "INF" case WarningLevel: - return "WARN" + return "WRN" case ErrorLevel: - return "ERRO" + return "ERR" case CriticalLevel: - return "CRIT" + return "CRT" default: - return "NONE" + return "NON" } } func formatLine(line *logLine, duplicates uint64, useColor bool) string { - colorStart := "" - colorEnd := "" + var colorStart, colorEnd, colorDim, colorEndDim string if useColor { colorStart = line.level.color() colorEnd = endColor() + colorDim = dimColor() + colorEndDim = endDimColor() } counter++ var fLine string if line.line == 0 { - fLine = fmt.Sprintf("%s%s ? %s %s %03d%s%s %s", colorStart, line.timestamp.Format(timeFormat), rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg) + fLine = fmt.Sprintf( + "%s%s%s %s%s%s %s? 
%s %03d%s%s %s", + + colorDim, + line.timestamp.Format(timeFormat), + colorEndDim, + + colorStart, + line.level.String(), + colorEnd, + + colorDim, + + rightArrow, + + counter, + formatDuplicates(duplicates), + colorEndDim, + + line.msg, + ) } else { fLen := len(line.file) fPartStart := fLen - 10 if fPartStart < 0 { fPartStart = 0 } - fLine = fmt.Sprintf("%s%s %s:%03d %s %s %03d%s%s %s", colorStart, line.timestamp.Format(timeFormat), line.file[fPartStart:], line.line, rightArrow, line.level.String(), counter, formatDuplicates(duplicates), colorEnd, line.msg) + fLine = fmt.Sprintf( + "%s%s%s %s%s%s %s%s:%03d %s %03d%s%s %s", + + colorDim, + line.timestamp.Format(timeFormat), + colorEndDim, + + colorStart, + line.level.String(), + colorEnd, + + colorDim, + line.file[fPartStart:], + line.line, + + rightArrow, + + counter, + formatDuplicates(duplicates), + colorEndDim, + + line.msg, + ) } if line.tracer != nil { @@ -78,7 +121,25 @@ func formatLine(line *logLine, duplicates uint64, useColor bool) string { } else { d = line.tracer.logs[i+1].timestamp.Sub(action.timestamp) } - fLine += fmt.Sprintf("\n%s%19s %s:%03d %s %s%s %s", colorStart, d, action.file[fPartStart:], action.line, rightArrow, action.level.String(), colorEnd, action.msg) + fLine += fmt.Sprintf( + "\n%s%23s%s %s%s%s %s%s:%03d %s%s %s", + colorDim, + d, + colorEndDim, + + colorStart, + action.level.String(), + colorEnd, + + colorDim, + action.file[fPartStart:], + action.line, + + rightArrow, + colorEndDim, + + action.msg, + ) } } diff --git a/base/log/formatting_unix.go b/base/log/formatting_unix.go index 6be6fdcca..a19c6abc6 100644 --- a/base/log/formatting_unix.go +++ b/base/log/formatting_unix.go @@ -8,11 +8,14 @@ const ( ) const ( - colorRed = "\033[31m" - colorYellow = "\033[33m" + colorDim = "\033[2m" + colorEndDim = "\033[22m" + colorRed = "\033[91m" + colorYellow = "\033[93m" colorBlue = "\033[34m" colorMagenta = "\033[35m" colorCyan = "\033[36m" + colorGreen = "\033[92m" // Saved for later: // colorBlack = "\033[30m" //. 
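For reference, the escape codes above compose per log line roughly as: dim, timestamp, end-dim, severity color, level tag, reset. A tiny standalone illustration of that composition using plain fmt (the constant values are copied from the block above; the program itself is not part of this patch):

package main

import (
	"fmt"
	"time"
)

const (
	colorDim    = "\033[2m"  // faint
	colorEndDim = "\033[22m" // back to normal intensity
	colorGreen  = "\033[92m" // bright green, used for INF
	colorReset  = "\033[0m"  // reset all attributes
)

func main() {
	// Prints a dimmed timestamp, a green "INF" tag, and the message.
	fmt.Printf(
		"%s%s%s %sINF%s hello world\n",
		colorDim, time.Now().Format("2006-01-02 15:04:05.000"), colorEndDim,
		colorGreen, colorReset,
	)
}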
@@ -25,7 +28,7 @@ func (s Severity) color() string { case DebugLevel: return colorCyan case InfoLevel: - return colorBlue + return colorGreen case WarningLevel: return colorYellow case ErrorLevel: @@ -42,3 +45,15 @@ func (s Severity) color() string { func endColor() string { return "\033[0m" } + +func blueColor() string { + return colorBlue +} + +func dimColor() string { + return colorDim +} + +func endDimColor() string { + return colorEndDim +} diff --git a/base/log/formatting_windows.go b/base/log/formatting_windows.go index 2c972d0a3..b3231f172 100644 --- a/base/log/formatting_windows.go +++ b/base/log/formatting_windows.go @@ -10,13 +10,17 @@ const ( ) const ( - // colorBlack = "\033[30m" - colorRed = "\033[31m" - // colorGreen = "\033[32m" - colorYellow = "\033[33m" + colorDim = "\033[2m" + colorEndDim = "\033[22m" + colorRed = "\033[91m" + colorYellow = "\033[93m" colorBlue = "\033[34m" colorMagenta = "\033[35m" colorCyan = "\033[36m" + colorGreen = "\033[92m" + + // colorBlack = "\033[30m" + // colorGreen = "\033[32m" // colorWhite = "\033[37m" ) @@ -34,7 +38,7 @@ func (s Severity) color() string { case DebugLevel: return colorCyan case InfoLevel: - return colorBlue + return colorGreen case WarningLevel: return colorYellow case ErrorLevel: @@ -54,3 +58,24 @@ func endColor() string { } return "" } + +func blueColor() string { + if colorsSupported { + return colorBlue + } + return "" +} + +func dimColor() string { + if colorsSupported { + return colorDim + } + return "" +} + +func endDimColor() string { + if colorsSupported { + return colorEndDim + } + return "" +} diff --git a/base/log/logging.go b/base/log/logging.go index fe777abad..b859bf110 100644 --- a/base/log/logging.go +++ b/base/log/logging.go @@ -142,6 +142,9 @@ func GetLogLevel() Severity { // SetLogLevel sets a new log level. Only effective after Start(). func SetLogLevel(level Severity) { atomic.StoreUint32(logLevel, uint32(level)) + + // Setup slog here for the transition period. + setupSLog(level) } // Name returns the name of the log level. @@ -199,6 +202,9 @@ func Start() (err error) { } SetLogLevel(initialLogLevel) + } else { + // Setup slog here for the transition period. 
+ setupSLog(GetLogLevel()) } // get and set file loglevels diff --git a/base/log/output.go b/base/log/output.go index d8a29a406..f80fe3519 100644 --- a/base/log/output.go +++ b/base/log/output.go @@ -100,7 +100,17 @@ func defaultColorFormater(line Message, duplicates uint64) string { } func startWriter() { - fmt.Printf("%s%s %s BOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), rightArrow, endColor()) + fmt.Printf( + "%s%s%s %sBOF %s%s\n", + + dimColor(), + time.Now().Format(timeFormat), + endDimColor(), + + blueColor(), + rightArrow, + endColor(), + ) shutdownWaitGroup.Add(1) go writerManager() @@ -225,7 +235,17 @@ func finalizeWriting() { case line := <-logBuffer: adapter.Write(line, 0) case <-time.After(10 * time.Millisecond): - fmt.Printf("%s%s %s EOF%s\n", InfoLevel.color(), time.Now().Format(timeFormat), leftArrow, endColor()) + fmt.Printf( + "%s%s%s %sEOF %s%s\n", + + dimColor(), + time.Now().Format(timeFormat), + endDimColor(), + + blueColor(), + leftArrow, + endColor(), + ) return } } diff --git a/base/log/slog.go b/base/log/slog.go new file mode 100644 index 000000000..d0f09aad9 --- /dev/null +++ b/base/log/slog.go @@ -0,0 +1,65 @@ +package log + +import ( + "log/slog" + "os" + "runtime" + + "github.com/lmittmann/tint" + "github.com/mattn/go-colorable" + "github.com/mattn/go-isatty" +) + +func setupSLog(logLevel Severity) { + // Convert to slog level. + var level slog.Level + switch logLevel { + case TraceLevel: + level = slog.LevelDebug + case DebugLevel: + level = slog.LevelDebug + case InfoLevel: + level = slog.LevelInfo + case WarningLevel: + level = slog.LevelWarn + case ErrorLevel: + level = slog.LevelError + case CriticalLevel: + level = slog.LevelError + } + + // Setup logging. + // Define output. + logOutput := os.Stdout + // Create handler depending on OS. + var logHandler slog.Handler + switch runtime.GOOS { + case "windows": + logHandler = tint.NewHandler( + colorable.NewColorable(logOutput), + &tint.Options{ + AddSource: true, + Level: level, + TimeFormat: timeFormat, + }, + ) + case "linux": + logHandler = tint.NewHandler(logOutput, &tint.Options{ + AddSource: true, + Level: level, + TimeFormat: timeFormat, + NoColor: !isatty.IsTerminal(logOutput.Fd()), + }) + default: + logHandler = tint.NewHandler(os.Stdout, &tint.Options{ + AddSource: true, + Level: level, + TimeFormat: timeFormat, + NoColor: true, + }) + } + + // Set as default logger. 
+ slog.SetDefault(slog.New(logHandler)) + slog.SetLogLoggerLevel(level) +} From a286c2ca675fc30b5f3538343b6091f13851a849 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 30 Jul 2024 17:07:51 +0200 Subject: [PATCH 45/56] Switch spn hub and observer cmds to new module system --- base/api/main.go | 25 +- base/config/main.go | 3 + base/config/module.go | 1 + base/database/main.go | 8 - base/database/registry.go | 79 ----- cmds/hub/main.go | 196 ++++++++--- cmds/observation-hub/main.go | 139 +++++++- cmds/observation-hub/observe.go | 2 - cmds/portmaster-core/main.go | 50 +-- service/core/base/global.go | 16 +- service/core/base/module.go | 10 +- service/firewall/inspection/inspection.go | 2 +- service/instance.go | 89 ++--- service/mgr/group.go | 29 +- service/mgr/group_ext.go | 4 +- service/netquery/module_api.go | 2 +- service/updates/module.go | 2 + service/updates/notify.go | 12 + spn/access/module.go | 60 ++-- spn/access/zones.go | 2 +- spn/captain/api.go | 5 + spn/captain/client.go | 5 +- spn/captain/config.go | 6 +- spn/captain/establish.go | 4 +- spn/captain/module.go | 116 +++--- spn/conf/mode.go | 25 +- spn/instance.go | 408 ++++++++++++++++++++++ spn/navigator/module.go | 7 +- spn/ships/http_shared.go | 3 +- spn/sluice/module.go | 2 +- 30 files changed, 947 insertions(+), 365 deletions(-) create mode 100644 spn/instance.go diff --git a/base/api/main.go b/base/api/main.go index 293997716..cc2187788 100644 --- a/base/api/main.go +++ b/base/api/main.go @@ -23,27 +23,30 @@ func init() { } func prep() error { - if exportEndpoints { - module.instance.SetCmdLineOperation(exportEndpointsCmd) - } - - if getDefaultListenAddress() == "" { - return errors.New("no default listen address for api available") - } - + // Register endpoints. if err := registerConfig(); err != nil { return err } - if err := registerDebugEndpoints(); err != nil { return err } - if err := registerConfigEndpoints(); err != nil { return err } + if err := registerMetaEndpoints(); err != nil { + return err + } + + if exportEndpoints { + module.instance.SetCmdLineOperation(exportEndpointsCmd) + return mgr.ErrExecuteCmdLineOp + } + + if getDefaultListenAddress() == "" { + return errors.New("no default listen address for api available") + } - return registerMetaEndpoints() + return nil } func start() error { diff --git a/base/config/main.go b/base/config/main.go index b4eef3846..0ed0b7e6a 100644 --- a/base/config/main.go +++ b/base/config/main.go @@ -13,6 +13,7 @@ import ( "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/utils" "github.com/safing/portmaster/base/utils/debug" + "github.com/safing/portmaster/service/mgr" ) // ChangeEvent is the name of the config change event. @@ -43,6 +44,7 @@ func prep() error { if exportConfig { module.instance.SetCmdLineOperation(exportConfigCmd) + return mgr.ErrExecuteCmdLineOp } return registerBasicOptions() @@ -135,6 +137,7 @@ func GetActiveConfigValues() map[string]interface{} { return values } +// InitializeUnitTestDataroot initializes a new random tmp directory for running tests. func InitializeUnitTestDataroot(testName string) (string, error) { basePath, err := os.MkdirTemp("", fmt.Sprintf("portmaster-%s", testName)) if err != nil { diff --git a/base/config/module.go b/base/config/module.go index afd7cc397..465e54549 100644 --- a/base/config/module.go +++ b/base/config/module.go @@ -16,6 +16,7 @@ type Config struct { EventConfigChange *mgr.EventMgr[struct{}] } +// Manager returns the module's manager. 
func (u *Config) Manager() *mgr.Manager { return u.mgr } diff --git a/base/database/main.go b/base/database/main.go index 9a03420db..f84a01086 100644 --- a/base/database/main.go +++ b/base/database/main.go @@ -3,7 +3,6 @@ package database import ( "errors" "fmt" - "path/filepath" "github.com/tevino/abool" @@ -41,13 +40,6 @@ func Initialize(dirStructureRoot *utils.DirStructure) error { return fmt.Errorf("could not create/open database directory (%s): %w", rootStructure.Path, err) } - if registryPersistence.IsSet() { - err = loadRegistry() - if err != nil { - return fmt.Errorf("could not load database registry (%s): %w", filepath.Join(rootStructure.Path, registryFileName), err) - } - } - return nil } return errors.New("database already initialized") diff --git a/base/database/registry.go b/base/database/registry.go index 44dcbae73..6c8f59e3b 100644 --- a/base/database/registry.go +++ b/base/database/registry.go @@ -1,12 +1,8 @@ package database import ( - "encoding/json" "errors" "fmt" - "io/fs" - "os" - "path" "regexp" "sync" "time" @@ -33,10 +29,6 @@ var ( // the description and the primary API will be // updated and the effective object will be returned. func Register(db *Database) (*Database, error) { - if !initialized.IsSet() { - return nil, errors.New("database not initialized") - } - registryLock.Lock() defer registryLock.Unlock() @@ -72,10 +64,6 @@ func Register(db *Database) (*Database, error) { if ok { registeredDB.Updated() } - err := saveRegistry(false) - if err != nil { - return nil, err - } } if ok { @@ -99,70 +87,3 @@ func getDatabase(name string) (*Database, error) { return registeredDB, nil } - -// EnableRegistryPersistence enables persistence of the database registry. -func EnableRegistryPersistence() { - if registryPersistence.SetToIf(false, true) { - // start registry writer - go registryWriter() - // TODO: make an initial write if database system is already initialized - } -} - -func loadRegistry() error { - registryLock.Lock() - defer registryLock.Unlock() - - // read file - filePath := path.Join(rootStructure.Path, registryFileName) - data, err := os.ReadFile(filePath) - if err != nil { - if errors.Is(err, fs.ErrNotExist) { - return nil - } - return err - } - - // parse - databases := make(map[string]*Database) - err = json.Unmarshal(data, &databases) - if err != nil { - return err - } - - // set - registry = databases - return nil -} - -func saveRegistry(lock bool) error { - if lock { - registryLock.Lock() - defer registryLock.Unlock() - } - - // marshal - data, err := json.MarshalIndent(registry, "", "\t") - if err != nil { - return err - } - - // write file - // TODO: write atomically (best effort) - filePath := path.Join(rootStructure.Path, registryFileName) - return os.WriteFile(filePath, data, 0o0600) -} - -func registryWriter() { - for { - select { - case <-time.After(1 * time.Hour): - if writeRegistrySoon.SetToIf(true, false) { - _ = saveRegistry(true) - } - case <-shutdownSignal: - _ = saveRegistry(true) - return - } - } -} diff --git a/cmds/hub/main.go b/cmds/hub/main.go index 4180e74d4..3db002b3d 100644 --- a/cmds/hub/main.go +++ b/cmds/hub/main.go @@ -1,65 +1,157 @@ package main import ( + "errors" "flag" - // "fmt" - // "os" - // "runtime" - - // "github.com/safing/portmaster/base/info" - // "github.com/safing/portmaster/base/metrics" - _ "github.com/safing/portmaster/service/core/base" - _ "github.com/safing/portmaster/service/ui" + "fmt" + "io" + "log/slog" + "os" + "os/signal" + "runtime" + "runtime/pprof" + "syscall" + "time" + + 
"github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" - // "github.com/safing/portmaster/service/updates/helper" - _ "github.com/safing/portmaster/spn/captain" - // "github.com/safing/portmaster/spn/conf" + "github.com/safing/portmaster/service/updates/helper" + "github.com/safing/portmaster/spn" + "github.com/safing/portmaster/spn/conf" ) func init() { flag.BoolVar(&updates.RebootOnRestart, "reboot-on-restart", false, "reboot server on auto-upgrade") } +var sigUSR1 = syscall.Signal(0xa) + func main() { - // FIXME: rewrite so it fits the new module system - // info.Set("SPN Hub", "0.7.7", "GPLv3") - - // // Configure metrics. - // _ = metrics.SetNamespace("hub") - - // // Configure updating. - // updates.UserAgent = fmt.Sprintf("SPN Hub (%s %s)", runtime.GOOS, runtime.GOARCH) - // helper.IntelOnly() - - // // Configure SPN mode. - // conf.EnablePublicHub(true) - // conf.EnableClient(false) - - // // Disable module management, as we want to start all modules. - // modules.DisableModuleManagement() - - // // Configure microtask threshold. - // // Scale with CPU/GOMAXPROCS count, but keep a baseline and minimum: - // // CPUs -> MicroTasks - // // 0 -> 8 (increased to minimum) - // // 1 -> 8 (increased to minimum) - // // 2 -> 8 - // // 3 -> 10 - // // 4 -> 12 - // // 8 -> 20 - // // 16 -> 36 - // // - // // Start with number of GOMAXPROCS. - // microTasksThreshold := runtime.GOMAXPROCS(0) * 2 - // // Use at least 4 microtasks based on GOMAXPROCS. - // if microTasksThreshold < 4 { - // microTasksThreshold = 4 - // } - // // Add a 4 microtask baseline. - // microTasksThreshold += 4 - // // Set threshold. - // modules.SetMaxConcurrentMicroTasks(microTasksThreshold) - - // // Start. - // os.Exit(run.Run()) + flag.Parse() + + // Set name and license. + info.Set("SPN Hub", "", "GPLv3") + + // Configure metrics. + _ = metrics.SetNamespace("hub") + + // Configure user agent and updates. + updates.UserAgent = fmt.Sprintf("SPN Hub (%s %s)", runtime.GOOS, runtime.GOARCH) + helper.IntelOnly() + + // Set SPN public hub mode. + conf.EnablePublicHub(true) + + // Set default log level. + log.SetLogLevel(log.WarningLevel) + _ = log.Start() + + // Create instance. + var execCmdLine bool + instance, err := spn.New() + switch { + case err == nil: + // Continue + case errors.Is(err, mgr.ErrExecuteCmdLineOp): + execCmdLine = true + default: + fmt.Printf("error creating an instance: %s\n", err) + os.Exit(2) + } + + // Execute command line operation, if requested or available. + switch { + case !execCmdLine: + // Run service. + case instance.CommandLineOperation == nil: + fmt.Println("command line operation execution requested, but not set") + os.Exit(3) + default: + // Run the function and exit. + err = instance.CommandLineOperation() + if err != nil { + fmt.Fprintf(os.Stderr, "command line operation failed: %s\n", err) + os.Exit(3) + } + os.Exit(0) + } + + // Start + go func() { + err = instance.Start() + if err != nil { + fmt.Printf("instance start failed: %s\n", err) + os.Exit(1) + } + }() + + // Wait for signal. 
+ signalCh := make(chan os.Signal, 1) + signal.Notify( + signalCh, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + sigUSR1, + ) + + select { + case sig := <-signalCh: + // Only print and continue to wait if SIGUSR1 + if sig == sigUSR1 { + printStackTo(os.Stderr, "PRINTING STACK ON REQUEST") + } else { + fmt.Println(" ") // CLI output. + slog.Warn("program was interrupted, stopping") + } + + case <-instance.Stopped(): + log.Shutdown() + os.Exit(instance.ExitCode()) + } + + // Catch signals during shutdown. + // Rapid unplanned disassembly after 5 interrupts. + go func() { + forceCnt := 5 + for { + <-signalCh + forceCnt-- + if forceCnt > 0 { + fmt.Printf(" again, but already shutting down - %d more to force\n", forceCnt) + } else { + printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") + os.Exit(1) + } + } + }() + + // Rapid unplanned disassembly after 3 minutes. + go func() { + time.Sleep(3 * time.Minute) + printStackTo(os.Stderr, "PRINTING STACK - TAKING TOO LONG FOR SHUTDOWN") + os.Exit(1) + }() + + // Stop instance. + if err := instance.Stop(); err != nil { + slog.Error("failed to stop", "err", err) + } + log.Shutdown() + os.Exit(instance.ExitCode()) +} + +func printStackTo(writer io.Writer, msg string) { + _, err := fmt.Fprintf(writer, "===== %s =====\n", msg) + if err == nil { + err = pprof.Lookup("goroutine").WriteTo(writer, 1) + } + if err != nil { + slog.Error("failed to write stack trace", "err", err) + } } diff --git a/cmds/observation-hub/main.go b/cmds/observation-hub/main.go index bf83a4ab4..b7a8ed712 100644 --- a/cmds/observation-hub/main.go +++ b/cmds/observation-hub/main.go @@ -1,53 +1,160 @@ package main import ( + "errors" + "flag" "fmt" + "io" + "log/slog" + "os" + "os/signal" "runtime" + "runtime/pprof" + "syscall" + "time" "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/service/updates/helper" + "github.com/safing/portmaster/spn" "github.com/safing/portmaster/spn/captain" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/sluice" ) +var sigUSR1 = syscall.Signal(0xa) + func main() { - info.Set("SPN Observation Hub", "0.7.1", "GPLv3") + flag.Parse() + + info.Set("SPN Observation Hub", "", "GPLv3") // Configure metrics. _ = metrics.SetNamespace("observer") - // Configure user agent. + // Configure user agent and updates. updates.UserAgent = fmt.Sprintf("SPN Observation Hub (%s %s)", runtime.GOOS, runtime.GOARCH) helper.IntelOnly() // Configure SPN mode. conf.EnableClient(true) - conf.EnablePublicHub(false) captain.DisableAccount = true // Disable unneeded listeners. sluice.EnableListener = false api.EnableServer = false - /// TODO(vladimir) initialize dependency modules + // Set default log level. + log.SetLogLevel(log.WarningLevel) + _ = log.Start() - // Disable module management, as we want to start all modules. - // module.DisableModuleManagement() - module, err := New(struct{}{}) - if err != nil { - fmt.Printf("error creating observer: %s\n", err) - return + // Create instance. 
+ var execCmdLine bool + instance, err := spn.New() + switch { + case err == nil: + // Continue + case errors.Is(err, mgr.ErrExecuteCmdLineOp): + execCmdLine = true + default: + fmt.Printf("error creating an instance: %s\n", err) + os.Exit(2) } - err = module.Start() - if err != nil { - fmt.Printf("failed to start observer: %s\n", err) - return + + // Execute command line operation, if requested or available. + switch { + case !execCmdLine: + // Run service. + case instance.CommandLineOperation == nil: + fmt.Println("command line operation execution requested, but not set") + os.Exit(3) + default: + // Run the function and exit. + err = instance.CommandLineOperation() + if err != nil { + fmt.Fprintf(os.Stderr, "command line operation failed: %s\n", err) + os.Exit(3) + } + os.Exit(0) + } + + // Start + go func() { + err = instance.Start() + if err != nil { + fmt.Printf("instance start failed: %s\n", err) + os.Exit(1) + } + }() + + // Wait for signal. + signalCh := make(chan os.Signal, 1) + signal.Notify( + signalCh, + os.Interrupt, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT, + sigUSR1, + ) + + select { + case sig := <-signalCh: + // Only print and continue to wait if SIGUSR1 + if sig == sigUSR1 { + printStackTo(os.Stderr, "PRINTING STACK ON REQUEST") + } else { + fmt.Println(" ") // CLI output. + slog.Warn("program was interrupted, stopping") + } + + case <-instance.Stopped(): + log.Shutdown() + os.Exit(instance.ExitCode()) } - // Start. - // os.Exit(run.Start()) + // Catch signals during shutdown. + // Rapid unplanned disassembly after 5 interrupts. + go func() { + forceCnt := 5 + for { + <-signalCh + forceCnt-- + if forceCnt > 0 { + fmt.Printf(" again, but already shutting down - %d more to force\n", forceCnt) + } else { + printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT") + os.Exit(1) + } + } + }() + + // Rapid unplanned disassembly after 3 minutes. + go func() { + time.Sleep(3 * time.Minute) + printStackTo(os.Stderr, "PRINTING STACK - TAKING TOO LONG FOR SHUTDOWN") + os.Exit(1) + }() + + // Stop instance. 
+ if err := instance.Stop(); err != nil { + slog.Error("failed to stop", "err", err) + } + log.Shutdown() + os.Exit(instance.ExitCode()) +} + +func printStackTo(writer io.Writer, msg string) { + _, err := fmt.Fprintf(writer, "===== %s =====\n", msg) + if err == nil { + err = pprof.Lookup("goroutine").WriteTo(writer, 1) + } + if err != nil { + slog.Error("failed to write stack trace", "err", err) + } } diff --git a/cmds/observation-hub/observe.go b/cmds/observation-hub/observe.go index 76f0c0a96..8835858e8 100644 --- a/cmds/observation-hub/observe.go +++ b/cmds/observation-hub/observe.go @@ -55,8 +55,6 @@ var ( ) func init() { - // observerModule = modules.Register("observer", prepObserver, startObserver, nil, "captain", "apprise") - flag.BoolVar(&reportAllChanges, "report-all-changes", false, "report all changes, no just interesting ones") flag.StringVar(&reportingDelayFlag, "reporting-delay", "10m", "delay reports to summarize changes") } diff --git a/cmds/portmaster-core/main.go b/cmds/portmaster-core/main.go index 407f67985..277046cfe 100644 --- a/cmds/portmaster-core/main.go +++ b/cmds/portmaster-core/main.go @@ -2,6 +2,7 @@ package main import ( + "errors" "flag" "fmt" "io" @@ -17,16 +18,9 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/metrics" "github.com/safing/portmaster/service" - "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/updates" "github.com/safing/portmaster/spn/conf" - - // Include packages here. - _ "github.com/safing/portmaster/service/core" - _ "github.com/safing/portmaster/service/firewall" - _ "github.com/safing/portmaster/service/nameserver" - _ "github.com/safing/portmaster/service/ui" - _ "github.com/safing/portmaster/spn/captain" ) var sigUSR1 = syscall.Signal(0xa) @@ -37,10 +31,6 @@ func main() { // set information info.Set("Portmaster", "", "GPLv3") - // Set default log level. - log.SetLogLevel(log.WarningLevel) - _ = log.Start() - // Configure metrics. _ = metrics.SetNamespace("portmaster") @@ -49,32 +39,42 @@ func main() { // enable SPN client mode conf.EnableClient(true) - - // Prep - err := base.GlobalPrep() - if err != nil { - fmt.Printf("global prep failed: %s\n", err) - return - } + conf.EnableIntegration(true) // Create instance. + var execCmdLine bool instance, err := service.New(&service.ServiceConfig{}) - if err != nil { + switch { + case err == nil: + // Continue + case errors.Is(err, mgr.ErrExecuteCmdLineOp): + execCmdLine = true + default: fmt.Printf("error creating an instance: %s\n", err) os.Exit(2) } - // Execute command line operation, if available. - if instance.CommandLineOperation != nil { + // Execute command line operation, if requested or available. + switch { + case !execCmdLine: + // Run service. + case instance.CommandLineOperation == nil: + fmt.Println("command line operation execution requested, but not set") + os.Exit(3) + default: // Run the function and exit. err = instance.CommandLineOperation() if err != nil { - fmt.Fprintf(os.Stderr, "cmdline operation failed: %s\n", err) + fmt.Fprintf(os.Stderr, "command line operation failed: %s\n", err) os.Exit(3) } os.Exit(0) } + // Set default log level. + log.SetLogLevel(log.WarningLevel) + _ = log.Start() + // Start go func() { err = instance.Start() @@ -107,6 +107,7 @@ func main() { } case <-instance.Stopped(): + log.Shutdown() os.Exit(instance.ExitCode()) } @@ -135,8 +136,9 @@ func main() { // Stop instance. 
if err := instance.Stop(); err != nil { - slog.Error("failed to stop portmaster", "err", err) + slog.Error("failed to stop", "err", err) } + log.Shutdown() os.Exit(instance.ExitCode()) } diff --git a/service/core/base/global.go b/service/core/base/global.go index 727fbd3de..3b1cc82f5 100644 --- a/service/core/base/global.go +++ b/service/core/base/global.go @@ -8,10 +8,9 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/dataroot" "github.com/safing/portmaster/base/info" + "github.com/safing/portmaster/service/mgr" ) -var ErrCleanExit = errors.New("clean exit requested") - // Default Values (changeable for testing). var ( DefaultAPIListenAddress = "127.0.0.1:817" @@ -25,11 +24,9 @@ func init() { flag.StringVar(&dataDir, "data", "", "set data directory") flag.StringVar(&databaseDir, "db", "", "alias to --data (deprecated)") flag.BoolVar(&showVersion, "version", false, "show version and exit") - - // modules.SetGlobalPrepFn(globalPrep) } -func GlobalPrep() error { +func prep(instance instance) error { // check if meta info is ok err := info.CheckVersion() if err != nil { @@ -38,8 +35,8 @@ func GlobalPrep() error { // print version if showVersion { - fmt.Println(info.FullVersion()) - return ErrCleanExit + instance.SetCmdLineOperation(printVersion) + return mgr.ErrExecuteCmdLineOp } // check data root @@ -68,3 +65,8 @@ func GlobalPrep() error { return nil } + +func printVersion() error { + fmt.Println(info.FullVersion()) + return nil +} diff --git a/service/core/base/module.go b/service/core/base/module.go index 848ea97e8..1082e2bfd 100644 --- a/service/core/base/module.go +++ b/service/core/base/module.go @@ -4,9 +4,6 @@ import ( "errors" "sync/atomic" - _ "github.com/safing/portmaster/base/config" - _ "github.com/safing/portmaster/base/metrics" - _ "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/mgr" ) @@ -46,6 +43,9 @@ func New(instance instance) (*Base, error) { instance: instance, } + if err := prep(instance); err != nil { + return nil, err + } if err := registerDatabases(); err != nil { return nil, err } @@ -53,4 +53,6 @@ func New(instance instance) (*Base, error) { return module, nil } -type instance interface{} +type instance interface { + SetCmdLineOperation(f func() error) +} diff --git a/service/firewall/inspection/inspection.go b/service/firewall/inspection/inspection.go index 44855ba49..a8a2448fa 100644 --- a/service/firewall/inspection/inspection.go +++ b/service/firewall/inspection/inspection.go @@ -7,7 +7,7 @@ import ( "github.com/safing/portmaster/service/network/packet" ) -//nolint:golint,stylecheck // FIXME +//nolint:golint,stylecheck const ( DO_NOTHING uint8 = iota BLOCK_PACKET diff --git a/service/instance.go b/service/instance.go index aa9adb4be..f4b7d3b76 100644 --- a/service/instance.go +++ b/service/instance.go @@ -109,161 +109,162 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { var err error // Base modules + instance.base, err = base.New(instance) + if err != nil { + return instance, fmt.Errorf("create base module: %w", err) + } instance.database, err = dbmodule.New(instance) if err != nil { - return nil, fmt.Errorf("create database module: %w", err) + return instance, fmt.Errorf("create database module: %w", err) } instance.config, err = config.New(instance) if err != nil { - return nil, fmt.Errorf("create config module: %w", err) + return instance, fmt.Errorf("create config module: %w", err) } instance.api, err = api.New(instance) if err != nil { - return nil, fmt.Errorf("create api module: 
%w", err) + return instance, fmt.Errorf("create api module: %w", err) } instance.metrics, err = metrics.New(instance) if err != nil { - return nil, fmt.Errorf("create metrics module: %w", err) + return instance, fmt.Errorf("create metrics module: %w", err) } instance.runtime, err = runtime.New(instance) if err != nil { - return nil, fmt.Errorf("create runtime module: %w", err) + return instance, fmt.Errorf("create runtime module: %w", err) } instance.notifications, err = notifications.New(instance) if err != nil { - return nil, fmt.Errorf("create runtime module: %w", err) + return instance, fmt.Errorf("create runtime module: %w", err) } instance.rng, err = rng.New(instance) if err != nil { - return nil, fmt.Errorf("create rng module: %w", err) - } - instance.base, err = base.New(instance) - if err != nil { - return nil, fmt.Errorf("create base module: %w", err) + return instance, fmt.Errorf("create rng module: %w", err) } // Service modules instance.core, err = core.New(instance) if err != nil { - return nil, fmt.Errorf("create core module: %w", err) + return instance, fmt.Errorf("create core module: %w", err) } instance.updates, err = updates.New(instance) if err != nil { - return nil, fmt.Errorf("create updates module: %w", err) + return instance, fmt.Errorf("create updates module: %w", err) } instance.geoip, err = geoip.New(instance) if err != nil { - return nil, fmt.Errorf("create customlist module: %w", err) + return instance, fmt.Errorf("create customlist module: %w", err) } instance.netenv, err = netenv.New(instance) if err != nil { - return nil, fmt.Errorf("create netenv module: %w", err) + return instance, fmt.Errorf("create netenv module: %w", err) } instance.ui, err = ui.New(instance) if err != nil { - return nil, fmt.Errorf("create ui module: %w", err) + return instance, fmt.Errorf("create ui module: %w", err) } instance.profile, err = profile.NewModule(instance) if err != nil { - return nil, fmt.Errorf("create profile module: %w", err) + return instance, fmt.Errorf("create profile module: %w", err) } instance.network, err = network.New(instance) if err != nil { - return nil, fmt.Errorf("create network module: %w", err) + return instance, fmt.Errorf("create network module: %w", err) } instance.netquery, err = netquery.NewModule(instance) if err != nil { - return nil, fmt.Errorf("create netquery module: %w", err) + return instance, fmt.Errorf("create netquery module: %w", err) } instance.firewall, err = firewall.New(instance) if err != nil { - return nil, fmt.Errorf("create firewall module: %w", err) + return instance, fmt.Errorf("create firewall module: %w", err) } instance.filterLists, err = filterlists.New(instance) if err != nil { - return nil, fmt.Errorf("create filterLists module: %w", err) + return instance, fmt.Errorf("create filterLists module: %w", err) } instance.interception, err = interception.New(instance) if err != nil { - return nil, fmt.Errorf("create interception module: %w", err) + return instance, fmt.Errorf("create interception module: %w", err) } instance.customlist, err = customlists.New(instance) if err != nil { - return nil, fmt.Errorf("create customlist module: %w", err) + return instance, fmt.Errorf("create customlist module: %w", err) } instance.status, err = status.New(instance) if err != nil { - return nil, fmt.Errorf("create status module: %w", err) + return instance, fmt.Errorf("create status module: %w", err) } instance.broadcasts, err = broadcasts.New(instance) if err != nil { - return nil, fmt.Errorf("create broadcasts module: %w", err) + 
return instance, fmt.Errorf("create broadcasts module: %w", err) } instance.compat, err = compat.New(instance) if err != nil { - return nil, fmt.Errorf("create compat module: %w", err) + return instance, fmt.Errorf("create compat module: %w", err) } instance.nameserver, err = nameserver.New(instance) if err != nil { - return nil, fmt.Errorf("create nameserver module: %w", err) + return instance, fmt.Errorf("create nameserver module: %w", err) } instance.process, err = process.New(instance) if err != nil { - return nil, fmt.Errorf("create process module: %w", err) + return instance, fmt.Errorf("create process module: %w", err) } instance.resolver, err = resolver.New(instance) if err != nil { - return nil, fmt.Errorf("create resolver module: %w", err) + return instance, fmt.Errorf("create resolver module: %w", err) } instance.sync, err = sync.New(instance) if err != nil { - return nil, fmt.Errorf("create sync module: %w", err) + return instance, fmt.Errorf("create sync module: %w", err) } instance.access, err = access.New(instance) if err != nil { - return nil, fmt.Errorf("create access module: %w", err) + return instance, fmt.Errorf("create access module: %w", err) } // SPN modules instance.cabin, err = cabin.New(instance) if err != nil { - return nil, fmt.Errorf("create cabin module: %w", err) + return instance, fmt.Errorf("create cabin module: %w", err) } instance.navigator, err = navigator.New(instance) if err != nil { - return nil, fmt.Errorf("create navigator module: %w", err) + return instance, fmt.Errorf("create navigator module: %w", err) } instance.captain, err = captain.New(instance) if err != nil { - return nil, fmt.Errorf("create captain module: %w", err) + return instance, fmt.Errorf("create captain module: %w", err) } instance.crew, err = crew.New(instance) if err != nil { - return nil, fmt.Errorf("create crew module: %w", err) + return instance, fmt.Errorf("create crew module: %w", err) } instance.docks, err = docks.New(instance) if err != nil { - return nil, fmt.Errorf("create docks module: %w", err) + return instance, fmt.Errorf("create docks module: %w", err) } instance.patrol, err = patrol.New(instance) if err != nil { - return nil, fmt.Errorf("create patrol module: %w", err) + return instance, fmt.Errorf("create patrol module: %w", err) } instance.ships, err = ships.New(instance) if err != nil { - return nil, fmt.Errorf("create ships module: %w", err) + return instance, fmt.Errorf("create ships module: %w", err) } instance.sluice, err = sluice.New(instance) if err != nil { - return nil, fmt.Errorf("create sluice module: %w", err) + return instance, fmt.Errorf("create sluice module: %w", err) } instance.terminal, err = terminal.New(instance) if err != nil { - return nil, fmt.Errorf("create terminal module: %w", err) + return instance, fmt.Errorf("create terminal module: %w", err) } // Add all modules to instance group. instance.serviceGroup = mgr.NewGroup( + instance.base, instance.database, instance.config, instance.api, @@ -271,7 +272,6 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { instance.runtime, instance.notifications, instance.rng, - instance.base, instance.core, instance.updates, @@ -293,6 +293,7 @@ func New(svcCfg *ServiceConfig) (*Instance, error) { instance.process, instance.resolver, instance.sync, + instance.access, ) @@ -472,8 +473,8 @@ func (i *Instance) Status() *status.Status { } // Broadcasts returns the broadcast module. 
-func (i *Instance) Broadcasts() *status.Status { - return i.status +func (i *Instance) Broadcasts() *broadcasts.Broadcasts { + return i.broadcasts } // Compat returns the compat module. @@ -584,7 +585,7 @@ func (i *Instance) Shutdown(exitCode int) { m.Go("shutdown", func(w *mgr.WorkerCtx) error { for { if err := i.Stop(); err != nil { - w.Error("failed to shutdown", "error", err, "retry", "1s") + w.Error("failed to shutdown", "err", err, "retry", "1s") time.Sleep(1 * time.Second) } else { return nil diff --git a/service/mgr/group.go b/service/mgr/group.go index b52798db8..4a97873bd 100644 --- a/service/mgr/group.go +++ b/service/mgr/group.go @@ -16,6 +16,10 @@ var ( // ErrInvalidGroupState is returned when a group is in an invalid state and cannot be recovered. ErrInvalidGroupState = errors.New("invalid group state") + + // ErrExecuteCmdLineOp is returned when modules are created, but request + // execution of a (somewhere else set) command line operation instead. + ErrExecuteCmdLineOp = errors.New("command line operation execution requested") ) const ( @@ -121,22 +125,26 @@ func (g *Group) Start() error { // Start modules. for i, m := range g.modules { - m.mgr.Info("starting") + m.mgr.Debug("starting") startTime := time.Now() err := m.mgr.Do("start module "+m.mgr.name, func(_ *WorkerCtx) error { - return m.module.Start() + return m.module.Start() //nolint:scopelint // Execution is synchronous. }) if err != nil { + m.mgr.Error( + "failed to start", + "err", err, + "time", time.Since(startTime), + ) if !g.stopFrom(i) { g.state.Store(groupStateInvalid) } else { g.state.Store(groupStateOff) } - return fmt.Errorf("failed to start %s: %w", makeModuleName(m.module), err) + return fmt.Errorf("failed to start %s: %w", m.mgr.name, err) } - duration := time.Since(startTime) - m.mgr.Info("started", "time", duration.String()) + m.mgr.Info("started", "time", time.Since(startTime)) } g.state.Store(groupStateRunning) @@ -175,23 +183,30 @@ func (g *Group) stopFrom(index int) (ok bool) { // Stop modules. 
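As an aside on the new mgr.ErrExecuteCmdLineOp sentinel above: it is meant to be paired with Instance.SetCmdLineOperation, which appears later in this series. A module constructor stores a one-off operation on the instance and aborts regular startup by returning the sentinel; errors.Is still matches because the instance constructors wrap module errors with %w. A minimal caller-side sketch, assuming the instance is built via the service package's New in a cmds/ entry point (that wiring is not part of this excerpt):

	instance, err := service.New(svcCfg)
	switch {
	case err == nil:
		// Regular startup path continues below.
	case errors.Is(err, mgr.ErrExecuteCmdLineOp):
		// A module requested a one-off command line operation (for example
		// --version handled by core/base) instead of a full start.
		if opErr := instance.CommandLineOperation(); opErr != nil {
			fmt.Fprintln(os.Stderr, "command line operation failed:", opErr)
			os.Exit(1)
		}
		os.Exit(0)
	default:
		fmt.Fprintln(os.Stderr, "failed to create instance:", err)
		os.Exit(2)
	}
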
for i := index; i >= 0; i-- { m := g.modules[i] + m.mgr.Debug("stopping") + startTime := time.Now() err := m.mgr.Do("stop module "+m.mgr.name, func(_ *WorkerCtx) error { return m.module.Stop() }) if err != nil { - m.mgr.Error("failed to stop", "err", err) + m.mgr.Error( + "failed to stop", + "err", err, + "time", time.Since(startTime), + ) ok = false } m.mgr.Cancel() if m.mgr.WaitForWorkers(0) { - m.mgr.Info("stopped") + m.mgr.Info("stopped", "time", time.Since(startTime)) } else { ok = false m.mgr.Error( "failed to stop", "err", "timed out", "workerCnt", m.mgr.workerCnt.Load(), + "time", time.Since(startTime), ) } } diff --git a/service/mgr/group_ext.go b/service/mgr/group_ext.go index dcd17236e..b965f5946 100644 --- a/service/mgr/group_ext.go +++ b/service/mgr/group_ext.go @@ -47,7 +47,7 @@ func (eg *ExtendedGroup) EnsureStartedWorker(wCtx *WorkerCtx) error { case err == nil: return nil case errors.Is(err, ErrInvalidGroupState): - wCtx.Debug("group start delayed", "error", err) + wCtx.Debug("group start delayed", "err", err) default: return err } @@ -78,7 +78,7 @@ func (eg *ExtendedGroup) EnsureStoppedWorker(wCtx *WorkerCtx) error { case err == nil: return nil case errors.Is(err, ErrInvalidGroupState): - wCtx.Debug("group stop delayed", "error", err) + wCtx.Debug("group stop delayed", "err", err) default: return err } diff --git a/service/netquery/module_api.go b/service/netquery/module_api.go index 9dde73002..497226b07 100644 --- a/service/netquery/module_api.go +++ b/service/netquery/module_api.go @@ -263,7 +263,7 @@ func (nq *NetQuery) Stop() error { // Cacnel the module context. nq.mgr.Cancel() // Wait for all workers before we start the shutdown. - nq.mgr.WaitForWorkers(time.Minute) + nq.mgr.WaitForWorkersFromStop(time.Minute) // we don't use the module ctx here because it is already canceled. // just give the clean up 1 minute to happen and abort otherwise. diff --git a/service/updates/module.go b/service/updates/module.go index 7d5da2b72..a1fdab159 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -6,6 +6,7 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/mgr" ) @@ -76,4 +77,5 @@ type instance interface { API() *api.API Config() *config.Config Shutdown(exitCode int) + Notifications() *notifications.Notifications } diff --git a/service/updates/notify.go b/service/updates/notify.go index 30b2bd32b..076eea347 100644 --- a/service/updates/notify.go +++ b/service/updates/notify.go @@ -21,7 +21,15 @@ const ( var updateFailedCnt = new(atomic.Int32) +func (u *Updates) notificationsEnabled() bool { + return u.instance.Notifications() != nil +} + func notifyUpdateSuccess(force bool) { + if !module.notificationsEnabled() { + return + } + updateFailedCnt.Store(0) module.states.Clear() updateState := registry.GetState().Updates @@ -133,6 +141,10 @@ func getUpdatingInfoMsg() string { } func notifyUpdateCheckFailed(force bool, err error) { + if !module.notificationsEnabled() { + return + } + failedCnt := updateFailedCnt.Add(1) lastSuccess := registry.GetState().Updates.LastSuccessAt diff --git a/spn/access/module.go b/spn/access/module.go index c8cf754fb..2b0765865 100644 --- a/spn/access/module.go +++ b/spn/access/module.go @@ -60,7 +60,7 @@ var ( func prep() error { // Register API handlers. 
- if conf.Client() { + if conf.Integrated() { err := registerAPIEndpoints() if err != nil { return err @@ -71,34 +71,34 @@ func prep() error { } func start() error { - // Add config listener to enable/disable SPN. - module.instance.Config().EventConfigChange.AddCallback("spn enable check", func(wc *mgr.WorkerCtx, s struct{}) (bool, error) { - // Do not do anything when we are shutting down. - if module.instance.Stopping() { - return true, nil - } + // Initialize zones. + if err := InitializeZones(); err != nil { + return err + } + if conf.Integrated() { + // Add config listener to enable/disable SPN. + module.instance.Config().EventConfigChange.AddCallback("spn enable check", func(wc *mgr.WorkerCtx, s struct{}) (bool, error) { + // Do not do anything when we are shutting down. + if module.instance.Stopping() { + return true, nil + } + + enabled := config.GetAsBool("spn/enable", false) + if enabled() { + module.mgr.Go("ensure SPN is started", module.instance.SPNGroup().EnsureStartedWorker) + } else { + module.mgr.Go("ensure SPN is stopped", module.instance.SPNGroup().EnsureStoppedWorker) + } + return false, nil + }) + + // Check if we need to enable SPN now. enabled := config.GetAsBool("spn/enable", false) if enabled() { module.mgr.Go("ensure SPN is started", module.instance.SPNGroup().EnsureStartedWorker) - } else { - module.mgr.Go("ensure SPN is stopped", module.instance.SPNGroup().EnsureStoppedWorker) } - return false, nil - }) - - // Check if we need to enable SPN now. - enabled := config.GetAsBool("spn/enable", false) - if enabled() { - module.mgr.Go("ensure SPN is started", module.instance.SPNGroup().EnsureStartedWorker) - } - - // Initialize zones. - if err := InitializeZones(); err != nil { - return err - } - if conf.Client() { // Load tokens from database. loadTokens() @@ -110,13 +110,13 @@ func start() error { } func stop() error { - // Make sure SPN is stopped before we proceed. - err := module.mgr.Do("ensure SPN is shut down", module.instance.SPNGroup().EnsureStoppedWorker) - if err != nil { - log.Errorf("access: stop SPN: %s", err) - } + if conf.Integrated() { + // Make sure SPN is stopped before we proceed. + err := module.mgr.Do("ensure SPN is shut down", module.instance.SPNGroup().EnsureStoppedWorker) + if err != nil { + log.Errorf("access: stop SPN: %s", err) + } - if conf.Client() { // Store tokens to database. storeTokens() } @@ -128,7 +128,7 @@ func stop() error { } // UpdateAccount updates the user account and fetches new tokens, if needed. -func UpdateAccount(_ *mgr.WorkerCtx) error { //, task *modules.Task) error { +func UpdateAccount(_ *mgr.WorkerCtx) error { // Schedule next call - this will change if other conditions are met bellow. module.updateAccountWorkerMgr.Delay(24 * time.Hour) diff --git a/spn/access/zones.go b/spn/access/zones.go index 0e550785f..165d8f85f 100644 --- a/spn/access/zones.go +++ b/spn/access/zones.go @@ -48,7 +48,7 @@ func InitializeZones() error { // Special client zone config. 
var requestSignalHandler func(token.Handler) - if conf.Client() { + if conf.Integrated() { requestSignalHandler = shouldRequestTokensHandler } diff --git a/spn/captain/api.go b/spn/captain/api.go index ec4987670..cfd38de49 100644 --- a/spn/captain/api.go +++ b/spn/captain/api.go @@ -7,6 +7,7 @@ import ( "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/query" + "github.com/safing/portmaster/spn/conf" ) const ( @@ -28,6 +29,10 @@ func registerAPIEndpoints() error { } func handleReInit(ar *api.Request) (msg string, err error) { + if !conf.Client() && !conf.Integrated() { + return "", fmt.Errorf("re-initialization only possible on integrated clients") + } + // Make sure SPN is stopped and wait for it to complete. err = module.mgr.Do("stop SPN for re-init", module.instance.SPNGroup().EnsureStoppedWorker) if err != nil { diff --git a/spn/captain/client.go b/spn/captain/client.go index 0827d317a..00c351bfc 100644 --- a/spn/captain/client.go +++ b/spn/captain/client.go @@ -13,6 +13,7 @@ import ( "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/spn/access" + "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/crew" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/navigator" @@ -144,7 +145,9 @@ reconnect: netenv.ConnectedToSPN.Set() module.EventSPNConnected.Submit(struct{}{}) - module.mgr.Go("update quick setting countries", navigator.Main.UpdateConfigQuickSettings) + if conf.Integrated() { + module.mgr.Go("update quick setting countries", navigator.Main.UpdateConfigQuickSettings) + } // Reset last health check value, as we have just connected. lastHealthCheck = time.Now() diff --git a/spn/captain/config.go b/spn/captain/config.go index dbfc2563c..bd5b7fc76 100644 --- a/spn/captain/config.go +++ b/spn/captain/config.go @@ -206,7 +206,11 @@ This setting mainly exists for when you need to simulate your presence in anothe } // Config options for use. - cfgOptionRoutingAlgorithm = config.Concurrent.GetAsString(profile.CfgOptionRoutingAlgorithmKey, navigator.DefaultRoutingProfileID) + if conf.Integrated() { + cfgOptionRoutingAlgorithm = config.Concurrent.GetAsString(profile.CfgOptionRoutingAlgorithmKey, navigator.DefaultRoutingProfileID) + } else { + cfgOptionRoutingAlgorithm = func() string { return navigator.DefaultRoutingProfileID } + } return nil } diff --git a/spn/captain/establish.go b/spn/captain/establish.go index ce322bd53..02fc28fc4 100644 --- a/spn/captain/establish.go +++ b/spn/captain/establish.go @@ -28,8 +28,8 @@ func EstablishCrane(callerCtx context.Context, dst *hub.Hub) (*docks.Crane, erro return nil, fmt.Errorf("failed to launch ship: %w", err) } - // On pure clients, mark all ships as public in order to show unmasked data in logs. - if conf.Client() && !conf.PublicHub() { + // If not a public hub, mark all ships as public in order to show unmasked data in logs. 
+ if !conf.PublicHub() { ship.MarkPublic() } diff --git a/spn/captain/module.go b/spn/captain/module.go index 4864f37ab..02742b02e 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -21,7 +21,6 @@ import ( "github.com/safing/portmaster/spn/navigator" "github.com/safing/portmaster/spn/patrol" "github.com/safing/portmaster/spn/ships" - _ "github.com/safing/portmaster/spn/sluice" ) const controlledFailureExitCode = 24 @@ -62,7 +61,7 @@ func (c *Captain) SetSleep(enabled bool) { } } -func prep() error { +func (c *Captain) prep() error { // Check if we can parse the bootstrap hub flag. if err := prepBootstrapHubFlag(); err != nil { return err @@ -84,7 +83,7 @@ func prep() error { return err } - module.instance.Patrol().EventChangeSignal.AddCallback( + c.instance.Patrol().EventChangeSignal.AddCallback( "trigger hub status maintenance", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { TriggerHubStatusMaintenance() @@ -103,6 +102,33 @@ func start() error { } ships.EnableMasking(maskingBytes) + // Initialize identity and piers. + if conf.PublicHub() { + // Load identity. + if err := loadPublicIdentity(); err != nil { + return fmt.Errorf("load public identity: %w", err) + } + + // Check if any networks are configured. + if !conf.HubHasIPv4() && !conf.HubHasIPv6() { + return errors.New("no IP addresses for Hub configured (or detected)") + } + + // Start management of identity and piers. + if err := prepPublicIdentityMgmt(); err != nil { + return err + } + // Set ID to display on http info page. + ships.DisplayHubID = publicIdentity.ID + // Start listeners. + if err := startPiers(); err != nil { + return err + } + + // Enable connect operation. + crew.EnableConnecting(publicIdentity.Hub) + } + // Initialize intel. module.mgr.Go("start", func(wc *mgr.WorkerCtx) error { if err := registerIntelUpdateHook(); err != nil { @@ -111,70 +137,38 @@ func start() error { if err := updateSPNIntel(module.mgr.Ctx(), nil); err != nil { log.Errorf("spn/captain: failed to update SPN intel: %s", err) } + return nil + }) - // Initialize identity and piers. - if conf.PublicHub() { - // Load identity. - if err := loadPublicIdentity(); err != nil { - // We cannot recover from this, set controlled failure (do not retry). - module.instance.Shutdown(controlledFailureExitCode) + // Subscribe to updates of cranes. + startDockHooks() - return err - } + // bootstrapping + if err := processBootstrapHubFlag(); err != nil { + return err + } + if err := processBootstrapFileFlag(); err != nil { + return err + } - // Check if any networks are configured. - if !conf.HubHasIPv4() && !conf.HubHasIPv6() { - // We cannot recover from this, set controlled failure (do not retry). - module.instance.Shutdown(controlledFailureExitCode) + // network optimizer + if conf.PublicHub() { + module.mgr.Delay("optimize network delay", 15*time.Second, optimizeNetwork).Repeat(1 * time.Minute) + } - return errors.New("no IP addresses for Hub configured (or detected)") - } + // client + home hub manager + if conf.Client() { + module.mgr.Go("client manager", clientManager) - // Start management of identity and piers. - if err := prepPublicIdentityMgmt(); err != nil { - return err - } - // Set ID to display on http info page. - ships.DisplayHubID = publicIdentity.ID - // Start listeners. - if err := startPiers(); err != nil { - return err + // Reset failing hubs when the network changes while not connected. 
+ module.instance.NetEnv().EventNetworkChange.AddCallback("reset failing hubs", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { + if ready.IsNotSet() { + navigator.Main.ResetFailingStates() } + return false, nil + }) + } - // Enable connect operation. - crew.EnableConnecting(publicIdentity.Hub) - } - - // Subscribe to updates of cranes. - startDockHooks() - - // bootstrapping - if err := processBootstrapHubFlag(); err != nil { - return err - } - if err := processBootstrapFileFlag(); err != nil { - return err - } - - // network optimizer - if conf.PublicHub() { - module.mgr.Delay("optimize network delay", 15*time.Second, optimizeNetwork).Repeat(1 * time.Minute) - } - - // client + home hub manager - if conf.Client() { - module.mgr.Go("client manager", clientManager) - - // Reset failing hubs when the network changes while not connected. - module.instance.NetEnv().EventNetworkChange.AddCallback("reset failing hubs", func(_ *mgr.WorkerCtx, _ struct{}) (bool, error) { - if ready.IsNotSet() { - navigator.Main.ResetFailingStates() - } - return false, nil - }) - } - return nil - }) return nil } @@ -236,7 +230,7 @@ func New(instance instance) (*Captain, error) { maintainPublicStatus: m.NewWorkerMgr("maintain public status", maintainPublicStatus, nil), } - if err := prep(); err != nil { + if err := module.prep(); err != nil { return nil, err } diff --git a/spn/conf/mode.go b/spn/conf/mode.go index cc1248bbe..295af6440 100644 --- a/spn/conf/mode.go +++ b/spn/conf/mode.go @@ -1,30 +1,41 @@ package conf import ( - "github.com/tevino/abool" + "sync/atomic" ) var ( - publicHub = abool.New() - client = abool.New() + publicHub atomic.Bool + client atomic.Bool + integrated atomic.Bool ) // PublicHub returns whether this is a public Hub. func PublicHub() bool { - return publicHub.IsSet() + return publicHub.Load() } // EnablePublicHub enables the public hub mode. func EnablePublicHub(enable bool) { - publicHub.SetTo(enable) + publicHub.Store(enable) } // Client returns whether this is a client. func Client() bool { - return client.IsSet() + return client.Load() } // EnableClient enables the client mode. func EnableClient(enable bool) { - client.SetTo(enable) + client.Store(enable) +} + +// Integrated returns whether SPN is running integrated into Portmaster. +func Integrated() bool { + return integrated.Load() +} + +// EnableIntegration enables the integrated mode. 
+func EnableIntegration(enable bool) { + integrated.Store(enable) } diff --git a/spn/instance.go b/spn/instance.go new file mode 100644 index 000000000..96866b52b --- /dev/null +++ b/spn/instance.go @@ -0,0 +1,408 @@ +package spn + +import ( + "context" + "fmt" + "sync/atomic" + "time" + + "github.com/safing/portmaster/base/api" + "github.com/safing/portmaster/base/config" + "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/metrics" + "github.com/safing/portmaster/base/notifications" + "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/service/core" + "github.com/safing/portmaster/service/core/base" + "github.com/safing/portmaster/service/intel/filterlists" + "github.com/safing/portmaster/service/intel/geoip" + "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/netenv" + "github.com/safing/portmaster/service/updates" + "github.com/safing/portmaster/spn/access" + "github.com/safing/portmaster/spn/cabin" + "github.com/safing/portmaster/spn/captain" + "github.com/safing/portmaster/spn/crew" + "github.com/safing/portmaster/spn/docks" + "github.com/safing/portmaster/spn/navigator" + "github.com/safing/portmaster/spn/patrol" + "github.com/safing/portmaster/spn/ships" + "github.com/safing/portmaster/spn/sluice" + "github.com/safing/portmaster/spn/terminal" +) + +// Instance is an instance of a Portmaster service. +type Instance struct { + ctx context.Context + cancelCtx context.CancelFunc + serviceGroup *mgr.Group + + exitCode atomic.Int32 + + base *base.Base + database *dbmodule.DBModule + config *config.Config + api *api.API + metrics *metrics.Metrics + rng *rng.Rng + + core *core.Core + updates *updates.Updates + geoip *geoip.GeoIP + netenv *netenv.NetEnv + filterLists *filterlists.FilterLists + + access *access.Access + cabin *cabin.Cabin + navigator *navigator.Navigator + captain *captain.Captain + crew *crew.Crew + docks *docks.Docks + patrol *patrol.Patrol + ships *ships.Ships + sluice *sluice.SluiceModule + terminal *terminal.TerminalModule + + CommandLineOperation func() error +} + +// New returns a new Portmaster service instance. +func New() (*Instance, error) { + // Create instance to pass it to modules. 
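The reworked conf flags above are set once at startup and only read afterwards. The call sites that combine them are outside this excerpt; based on how the code in this series branches on them, a plausible per-deployment mapping looks like this (a sketch, not the actual initialization code):

	// Portmaster with built-in SPN: a client that is driven via the config system.
	conf.EnableClient(true)
	conf.EnableIntegration(true)

	// Standalone SPN client: client behavior without the Portmaster config wiring.
	conf.EnableClient(true)
	conf.EnableIntegration(false)

	// Public SPN hub: neither client nor integrated.
	conf.EnablePublicHub(true)
	conf.EnableClient(false)
	conf.EnableIntegration(false)
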
+ instance := &Instance{} + instance.ctx, instance.cancelCtx = context.WithCancel(context.Background()) + + var err error + + // Base modules + instance.base, err = base.New(instance) + if err != nil { + return instance, fmt.Errorf("create base module: %w", err) + } + instance.database, err = dbmodule.New(instance) + if err != nil { + return instance, fmt.Errorf("create database module: %w", err) + } + instance.config, err = config.New(instance) + if err != nil { + return instance, fmt.Errorf("create config module: %w", err) + } + instance.api, err = api.New(instance) + if err != nil { + return instance, fmt.Errorf("create api module: %w", err) + } + instance.metrics, err = metrics.New(instance) + if err != nil { + return instance, fmt.Errorf("create metrics module: %w", err) + } + instance.rng, err = rng.New(instance) + if err != nil { + return instance, fmt.Errorf("create rng module: %w", err) + } + + // Service modules + instance.core, err = core.New(instance) + if err != nil { + return instance, fmt.Errorf("create core module: %w", err) + } + instance.updates, err = updates.New(instance) + if err != nil { + return instance, fmt.Errorf("create updates module: %w", err) + } + instance.geoip, err = geoip.New(instance) + if err != nil { + return instance, fmt.Errorf("create customlist module: %w", err) + } + instance.netenv, err = netenv.New(instance) + if err != nil { + return instance, fmt.Errorf("create netenv module: %w", err) + } + instance.filterLists, err = filterlists.New(instance) + if err != nil { + return instance, fmt.Errorf("create filterLists module: %w", err) + } + + // SPN modules + instance.access, err = access.New(instance) + if err != nil { + return instance, fmt.Errorf("create access module: %w", err) + } + instance.cabin, err = cabin.New(instance) + if err != nil { + return instance, fmt.Errorf("create cabin module: %w", err) + } + instance.navigator, err = navigator.New(instance) + if err != nil { + return instance, fmt.Errorf("create navigator module: %w", err) + } + instance.crew, err = crew.New(instance) + if err != nil { + return instance, fmt.Errorf("create crew module: %w", err) + } + instance.docks, err = docks.New(instance) + if err != nil { + return instance, fmt.Errorf("create docks module: %w", err) + } + instance.patrol, err = patrol.New(instance) + if err != nil { + return instance, fmt.Errorf("create patrol module: %w", err) + } + instance.ships, err = ships.New(instance) + if err != nil { + return instance, fmt.Errorf("create ships module: %w", err) + } + instance.sluice, err = sluice.New(instance) + if err != nil { + return instance, fmt.Errorf("create sluice module: %w", err) + } + instance.terminal, err = terminal.New(instance) + if err != nil { + return instance, fmt.Errorf("create terminal module: %w", err) + } + instance.captain, err = captain.New(instance) + if err != nil { + return instance, fmt.Errorf("create captain module: %w", err) + } + + // Add all modules to instance group. + instance.serviceGroup = mgr.NewGroup( + instance.base, + instance.database, + instance.config, + instance.api, + instance.metrics, + instance.rng, + + instance.core, + instance.updates, + instance.geoip, + instance.netenv, + + instance.access, + instance.cabin, + instance.navigator, + instance.captain, + instance.crew, + instance.docks, + instance.patrol, + instance.ships, + instance.sluice, + instance.terminal, + ) + + return instance, nil +} + +// SleepyModule is an interface for modules that can enter some sort of sleep mode. 
+type SleepyModule interface { + SetSleep(enabled bool) +} + +// SetSleep sets sleep mode on all modules that satisfy the SleepyModule interface. +func (i *Instance) SetSleep(enabled bool) { + for _, module := range i.serviceGroup.Modules() { + if sm, ok := module.(SleepyModule); ok { + sm.SetSleep(enabled) + } + } +} + +// Database returns the database module. +func (i *Instance) Database() *dbmodule.DBModule { + return i.database +} + +// Config returns the config module. +func (i *Instance) Config() *config.Config { + return i.config +} + +// API returns the api module. +func (i *Instance) API() *api.API { + return i.api +} + +// Metrics returns the metrics module. +func (i *Instance) Metrics() *metrics.Metrics { + return i.metrics +} + +// Rng returns the rng module. +func (i *Instance) Rng() *rng.Rng { + return i.rng +} + +// Base returns the base module. +func (i *Instance) Base() *base.Base { + return i.base +} + +// Updates returns the updates module. +func (i *Instance) Updates() *updates.Updates { + return i.updates +} + +// GeoIP returns the geoip module. +func (i *Instance) GeoIP() *geoip.GeoIP { + return i.geoip +} + +// NetEnv returns the netenv module. +func (i *Instance) NetEnv() *netenv.NetEnv { + return i.netenv +} + +// Access returns the access module. +func (i *Instance) Access() *access.Access { + return i.access +} + +// Cabin returns the cabin module. +func (i *Instance) Cabin() *cabin.Cabin { + return i.cabin +} + +// Captain returns the captain module. +func (i *Instance) Captain() *captain.Captain { + return i.captain +} + +// Crew returns the crew module. +func (i *Instance) Crew() *crew.Crew { + return i.crew +} + +// Docks returns the crew module. +func (i *Instance) Docks() *docks.Docks { + return i.docks +} + +// Navigator returns the navigator module. +func (i *Instance) Navigator() *navigator.Navigator { + return i.navigator +} + +// Patrol returns the patrol module. +func (i *Instance) Patrol() *patrol.Patrol { + return i.patrol +} + +// Ships returns the ships module. +func (i *Instance) Ships() *ships.Ships { + return i.ships +} + +// Sluice returns the ships module. +func (i *Instance) Sluice() *sluice.SluiceModule { + return i.sluice +} + +// Terminal returns the terminal module. +func (i *Instance) Terminal() *terminal.TerminalModule { + return i.terminal +} + +// FilterLists returns the filterLists module. +func (i *Instance) FilterLists() *filterlists.FilterLists { + return i.filterLists +} + +// Core returns the core module. +func (i *Instance) Core() *core.Core { + return i.core +} + +// Events + +// GetEventSPNConnected return the event manager for the SPN connected event. +func (i *Instance) GetEventSPNConnected() *mgr.EventMgr[struct{}] { + return i.captain.EventSPNConnected +} + +// Special functions + +// SetCmdLineOperation sets a command line operation to be executed instead of starting the system. This is useful when functions need all modules to be prepared for a special operation. +func (i *Instance) SetCmdLineOperation(f func() error) { + i.CommandLineOperation = f +} + +// GetStates returns the current states of all group modules. +func (i *Instance) GetStates() []mgr.StateUpdate { + return i.serviceGroup.GetStates() +} + +// AddStatesCallback adds the given callback function to all group modules that +// expose a state manager at States(). 
+func (i *Instance) AddStatesCallback(callbackName string, callback mgr.EventCallbackFunc[mgr.StateUpdate]) { + i.serviceGroup.AddStatesCallback(callbackName, callback) +} + +// Ready returns whether all modules in the main service module group have been started and are still running. +func (i *Instance) Ready() bool { + return i.serviceGroup.Ready() +} + +// Ctx returns the instance context. +// It is only canceled on shutdown. +func (i *Instance) Ctx() context.Context { + return i.ctx +} + +// Start starts the instance. +func (i *Instance) Start() error { + return i.serviceGroup.Start() +} + +// Stop stops the instance and cancels the instance context when done. +func (i *Instance) Stop() error { + defer i.cancelCtx() + return i.serviceGroup.Stop() +} + +// Shutdown asynchronously stops the instance. +func (i *Instance) Shutdown(exitCode int) { + i.exitCode.Store(int32(exitCode)) + + m := mgr.New("instance") + m.Go("shutdown", func(w *mgr.WorkerCtx) error { + for { + if err := i.Stop(); err != nil { + w.Error("failed to shutdown", "err", err, "retry", "1s") + time.Sleep(1 * time.Second) + } else { + return nil + } + } + }) +} + +// Stopping returns whether the instance is shutting down. +func (i *Instance) Stopping() bool { + return i.ctx.Err() == nil +} + +// Stopped returns a channel that is triggered when the instance has shut down. +func (i *Instance) Stopped() <-chan struct{} { + return i.ctx.Done() +} + +// SetExitCode sets the exit code on the instance. +func (i *Instance) SetExitCode(exitCode int) { + i.exitCode.Store(int32(exitCode)) +} + +// ExitCode returns the set exit code of the instance. +func (i *Instance) ExitCode() int { + return int(i.exitCode.Load()) +} + +// SPNGroup fakes interface conformance. +// SPNGroup is only needed on SPN clients. +func (i *Instance) SPNGroup() *mgr.ExtendedGroup { + return nil +} + +// Unsupported Modules. + +// Notifications returns nil. 
+func (i *Instance) Notifications() *notifications.Notifications { return nil } diff --git a/spn/navigator/module.go b/spn/navigator/module.go index 0d8f235c1..2568cb7e3 100644 --- a/spn/navigator/module.go +++ b/spn/navigator/module.go @@ -71,9 +71,14 @@ func prep() error { func start() error { Main = NewMap(conf.MainMapName, true) devMode = config.Concurrent.GetAsBool(config.CfgDevModeKey, false) - cfgOptionRoutingAlgorithm = config.Concurrent.GetAsString(cfgOptionRoutingAlgorithmKey, DefaultRoutingProfileID) cfgOptionTrustNodeNodes = config.Concurrent.GetAsStringArray(cfgOptionTrustNodeNodesKey, []string{}) + if conf.Integrated() { + cfgOptionRoutingAlgorithm = config.Concurrent.GetAsString(cfgOptionRoutingAlgorithmKey, DefaultRoutingProfileID) + } else { + cfgOptionRoutingAlgorithm = func() string { return DefaultRoutingProfileID } + } + err := registerMapDatabase() if err != nil { return err diff --git a/spn/ships/http_shared.go b/spn/ships/http_shared.go index eddbea232..bae861e57 100644 --- a/spn/ships/http_shared.go +++ b/spn/ships/http_shared.go @@ -97,8 +97,7 @@ func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) error { WriteTimeout: 1 * time.Minute, IdleTimeout: 1 * time.Minute, MaxHeaderBytes: 4096, - // ErrorLog: &log.Logger{}, // FIXME - BaseContext: func(net.Listener) context.Context { return module.mgr.Ctx() }, + BaseContext: func(net.Listener) context.Context { return module.mgr.Ctx() }, } shared.server = server diff --git a/spn/sluice/module.go b/spn/sluice/module.go index 99197dea5..88033d479 100644 --- a/spn/sluice/module.go +++ b/spn/sluice/module.go @@ -39,7 +39,7 @@ func start() error { // Listening on all interfaces for now, as we need this for Windows. // Handle similarly to the nameserver listener. - if conf.Client() && EnableListener { + if conf.Integrated() && EnableListener { StartSluice("tcp4", "0.0.0.0:717") StartSluice("udp4", "0.0.0.0:717") From 696481a1b746bf84c05f95f38b5f772dd9611081 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 30 Jul 2024 17:08:05 +0200 Subject: [PATCH 46/56] Fix log sources --- service/mgr/manager.go | 69 ++++++++++++++++++++++++++++++++++-------- service/mgr/worker.go | 56 ++++++++++++++++++++++++++-------- 2 files changed, 101 insertions(+), 24 deletions(-) diff --git a/service/mgr/manager.go b/service/mgr/manager.go index 4be527854..2070a2909 100644 --- a/service/mgr/manager.go +++ b/service/mgr/manager.go @@ -3,6 +3,7 @@ package mgr import ( "context" "log/slog" + "runtime" "sync/atomic" "time" ) @@ -78,49 +79,91 @@ func (m *Manager) LogEnabled(level slog.Level) bool { // Debug logs at LevelDebug. // The manager context is automatically supplied. func (m *Manager) Debug(msg string, args ...any) { - m.logger.DebugContext(m.ctx, msg, args...) + if !m.logger.Enabled(m.ctx, slog.LevelDebug) { + return + } + m.writeLog(slog.LevelDebug, msg, args...) } // Info logs at LevelInfo. // The manager context is automatically supplied. func (m *Manager) Info(msg string, args ...any) { - m.logger.InfoContext(m.ctx, msg, args...) + if !m.logger.Enabled(m.ctx, slog.LevelInfo) { + return + } + m.writeLog(slog.LevelInfo, msg, args...) } // Warn logs at LevelWarn. // The manager context is automatically supplied. func (m *Manager) Warn(msg string, args ...any) { - m.logger.WarnContext(m.ctx, msg, args...) + if !m.logger.Enabled(m.ctx, slog.LevelWarn) { + return + } + m.writeLog(slog.LevelWarn, msg, args...) } // Error logs at LevelError. // The manager context is automatically supplied. 
func (m *Manager) Error(msg string, args ...any) { - m.logger.ErrorContext(m.ctx, msg, args...) + if !m.logger.Enabled(m.ctx, slog.LevelError) { + return + } + m.writeLog(slog.LevelError, msg, args...) } // Log emits a log record with the current time and the given level and message. // The manager context is automatically supplied. func (m *Manager) Log(level slog.Level, msg string, args ...any) { - m.logger.Log(m.ctx, level, msg, args...) + if !m.logger.Enabled(m.ctx, level) { + return + } + m.writeLog(level, msg, args...) } // LogAttrs is a more efficient version of Log() that accepts only Attrs. // The manager context is automatically supplied. func (m *Manager) LogAttrs(level slog.Level, msg string, attrs ...slog.Attr) { - m.logger.LogAttrs(m.ctx, level, msg, attrs...) + if !m.logger.Enabled(m.ctx, level) { + return + } + + var pcs [1]uintptr + runtime.Callers(2, pcs[:]) // skip "Callers" and "LogAttrs". + r := slog.NewRecord(time.Now(), level, msg, pcs[0]) + r.AddAttrs(attrs...) + _ = m.logger.Handler().Handle(m.ctx, r) +} + +func (m *Manager) writeLog(level slog.Level, msg string, args ...any) { + var pcs [1]uintptr + runtime.Callers(3, pcs[:]) // skip "Callers", "writeLog" and the calling function. + r := slog.NewRecord(time.Now(), level, msg, pcs[0]) + r.Add(args...) + _ = m.logger.Handler().Handle(m.ctx, r) } // WaitForWorkers waits for all workers of this manager to be done. // The default maximum waiting time is one minute. func (m *Manager) WaitForWorkers(max time.Duration) (done bool) { + return m.waitForWorkers(max, 0) +} + +// WaitForWorkersFromStop is a special version of WaitForWorkers, meant to be called from the stop routine. +// It waits for all workers of this manager to be done, except for the Stop function. +// The default maximum waiting time is one minute. +func (m *Manager) WaitForWorkersFromStop(max time.Duration) (done bool) { + return m.waitForWorkers(max, 1) +} + +func (m *Manager) waitForWorkers(max time.Duration, limit int32) (done bool) { // Return immediately if there are no workers. - if m.workerCnt.Load() == 0 { + if m.workerCnt.Load() <= limit { return true } // Setup timers. - reCheckDuration := 100 * time.Millisecond + reCheckDuration := 10 * time.Millisecond if max <= 0 { max = time.Minute } @@ -131,13 +174,15 @@ func (m *Manager) WaitForWorkers(max time.Duration) (done bool) { // Wait for workers to finish, plus check the count in intervals. for { - if m.workerCnt.Load() == 0 { + if m.workerCnt.Load() <= limit { return true } select { case <-m.workersDone: - return true + if m.workerCnt.Load() <= limit { + return true + } case <-reCheck.C: // Check worker count again. @@ -146,7 +191,7 @@ func (m *Manager) WaitForWorkers(max time.Duration) (done bool) { reCheck.Reset(reCheckDuration) case <-maxWait.C: - return m.workerCnt.Load() == 0 + return m.workerCnt.Load() <= limit } } } @@ -156,7 +201,7 @@ func (m *Manager) workerStart() { } func (m *Manager) workerDone() { - if m.workerCnt.Add(-1) == 0 { + if m.workerCnt.Add(-1) <= 1 { // Notify all waiters. for { select { diff --git a/service/mgr/worker.go b/service/mgr/worker.go index 9f7eb2ee0..510dedfef 100644 --- a/service/mgr/worker.go +++ b/service/mgr/worker.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "os" + "runtime" "runtime/debug" "strings" "time" @@ -83,37 +84,68 @@ func (w *WorkerCtx) LogEnabled(level slog.Level) bool { // Debug logs at LevelDebug. // The worker context is automatically supplied. func (w *WorkerCtx) Debug(msg string, args ...any) { - w.logger.DebugContext(w.ctx, msg, args...) 
+ if !w.logger.Enabled(w.ctx, slog.LevelDebug) { + return + } + w.writeLog(slog.LevelDebug, msg, args...) } // Info logs at LevelInfo. // The worker context is automatically supplied. func (w *WorkerCtx) Info(msg string, args ...any) { - w.logger.InfoContext(w.ctx, msg, args...) + if !w.logger.Enabled(w.ctx, slog.LevelInfo) { + return + } + w.writeLog(slog.LevelInfo, msg, args...) } // Warn logs at LevelWarn. // The worker context is automatically supplied. func (w *WorkerCtx) Warn(msg string, args ...any) { - w.logger.WarnContext(w.ctx, msg, args...) + if !w.logger.Enabled(w.ctx, slog.LevelWarn) { + return + } + w.writeLog(slog.LevelWarn, msg, args...) } // Error logs at LevelError. // The worker context is automatically supplied. func (w *WorkerCtx) Error(msg string, args ...any) { - w.logger.ErrorContext(w.ctx, msg, args...) + if !w.logger.Enabled(w.ctx, slog.LevelError) { + return + } + w.writeLog(slog.LevelError, msg, args...) } // Log emits a log record with the current time and the given level and message. // The worker context is automatically supplied. func (w *WorkerCtx) Log(level slog.Level, msg string, args ...any) { - w.logger.Log(w.ctx, level, msg, args...) + if !w.logger.Enabled(w.ctx, level) { + return + } + w.writeLog(level, msg, args...) } // LogAttrs is a more efficient version of Log() that accepts only Attrs. // The worker context is automatically supplied. func (w *WorkerCtx) LogAttrs(level slog.Level, msg string, attrs ...slog.Attr) { - w.logger.LogAttrs(w.ctx, level, msg, attrs...) + if !w.logger.Enabled(w.ctx, level) { + return + } + + var pcs [1]uintptr + runtime.Callers(2, pcs[:]) // skip "Callers" and "LogAttrs". + r := slog.NewRecord(time.Now(), level, msg, pcs[0]) + r.AddAttrs(attrs...) + _ = w.logger.Handler().Handle(w.ctx, r) +} + +func (w *WorkerCtx) writeLog(level slog.Level, msg string, args ...any) { + var pcs [1]uintptr + runtime.Callers(3, pcs[:]) // skip "Callers", "writeLog" and the calling function. + r := slog.NewRecord(time.Now(), level, msg, pcs[0]) + r.Add(args...) + _ = w.logger.Handler().Handle(w.ctx, r) } // Go starts the given function in a goroutine (as a "worker"). @@ -157,13 +189,13 @@ func (m *Manager) manageWorker(name string, fn func(w *WorkerCtx) error) { // If manager is stopping, just log error and return. if m.IsDone() { if panicInfo != "" { - m.Error( + w.Error( "worker failed", "err", err, "file", panicInfo, ) } else { - m.Error( + w.Error( "worker failed", "err", err, ) @@ -180,7 +212,7 @@ func (m *Manager) manageWorker(name string, fn func(w *WorkerCtx) error) { // Log error and retry after backoff duration. if panicInfo != "" { - m.Error( + w.Error( "worker failed", "failCnt", failCnt, "backoff", backoff, @@ -188,7 +220,7 @@ func (m *Manager) manageWorker(name string, fn func(w *WorkerCtx) error) { "file", panicInfo, ) } else { - m.Error( + w.Error( "worker failed", "failCnt", failCnt, "backoff", backoff, @@ -235,13 +267,13 @@ func (m *Manager) Do(name string, fn func(w *WorkerCtx) error) error { default: // Log error and return. 
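The pattern above is the standard way to preserve source locations when wrapping log/slog: capture the caller's program counter and build the slog.Record manually instead of going through the Logger convenience methods, which would otherwise report the wrapper as the source. A self-contained sketch of the same technique outside this code base (illustrative only):

	package logutil

	import (
		"context"
		"log/slog"
		"runtime"
		"time"
	)

	// logWithSource logs through a wrapper while still reporting the wrapper's
	// caller as the source, which handlers show when AddSource is enabled.
	func logWithSource(l *slog.Logger, level slog.Level, msg string, args ...any) {
		if !l.Enabled(context.Background(), level) {
			return
		}
		var pcs [1]uintptr
		runtime.Callers(2, pcs[:]) // skip runtime.Callers and logWithSource itself
		r := slog.NewRecord(time.Now(), level, msg, pcs[0])
		r.Add(args...)
		_ = l.Handler().Handle(context.Background(), r)
	}
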
if panicInfo != "" { - m.Error( + w.Error( "worker failed", "err", err, "file", panicInfo, ) } else { - m.Error( + w.Error( "worker failed", "err", err, ) From b1db2e94a940d9d8bb7e7dcecea8f7efd5c65ffa Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 30 Jul 2024 17:08:21 +0200 Subject: [PATCH 47/56] Make worker mgr less error prone --- service/mgr/workermgr.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/service/mgr/workermgr.go b/service/mgr/workermgr.go index 0b40a3a97..55ba3b3ac 100644 --- a/service/mgr/workermgr.go +++ b/service/mgr/workermgr.go @@ -192,9 +192,10 @@ manage: // Run worker. wCtx := &WorkerCtx{ - logger: s.mgr.logger.With("worker", s.name), + workerMgr: s, + logger: s.ctx.logger, } - wCtx.ctx, wCtx.cancelCtx = context.WithCancel(s.mgr.Ctx()) + wCtx.ctx, wCtx.cancelCtx = context.WithCancel(s.ctx.ctx) panicInfo, err := s.mgr.runWorker(wCtx, s.fn) switch { @@ -207,13 +208,13 @@ manage: default: // Log error and return. if panicInfo != "" { - s.ctx.Error( + wCtx.Error( "worker failed", "err", err, "file", panicInfo, ) } else { - s.ctx.Error( + wCtx.Error( "worker failed", "err", err, ) @@ -231,11 +232,21 @@ manage: // Go executes the worker immediately. // If the worker is currently being executed, // the next execution will commence afterwards. -// Can only be called after calling one of Delay(), Repeat() or KeepAlive(). +// Calling Go() implies KeepAlive() if nothing else was specified yet. func (s *WorkerMgr) Go() { s.actionLock.Lock() defer s.actionLock.Unlock() + // Check if any action is already defined. + switch { + case s.delay != nil: + case s.repeat != nil: + case s.keepAlive != nil: + default: + s.keepAlive = &workerMgrNoop{} + s.check() + } + // Reset repeat if set. s.repeat.Reset() @@ -295,6 +306,8 @@ func (s *WorkerMgr) KeepAlive() *WorkerMgr { defer s.actionLock.Unlock() s.keepAlive = &workerMgrNoop{} + + s.check() return s } From 4f76f4343692651ad246494bfa418d6c0afefbf7 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Aug 2024 16:53:09 +0200 Subject: [PATCH 48/56] Fix tests and minor issues --- base/info/module/flags.go | 62 ---------------------------- base/rng/test/main.go | 4 +- go.mod | 4 +- go.sum | 6 +++ spn/cabin/identity_test.go | 32 ++++++++++---- spn/docks/terminal_expansion_test.go | 22 ++++++---- 6 files changed, 50 insertions(+), 80 deletions(-) delete mode 100644 base/info/module/flags.go diff --git a/base/info/module/flags.go b/base/info/module/flags.go deleted file mode 100644 index 59e974ae7..000000000 --- a/base/info/module/flags.go +++ /dev/null @@ -1,62 +0,0 @@ -package module - -import ( - "errors" - "flag" - "fmt" - "sync/atomic" - - "github.com/safing/portmaster/base/info" - "github.com/safing/portmaster/service/core/base" - "github.com/safing/portmaster/service/mgr" -) - -type Info struct { - instance instance -} - -var showVersion bool - -func init() { - flag.BoolVar(&showVersion, "version", false, "show version and exit") -} - -func (i *Info) Start(m *mgr.Manager) error { - err := info.CheckVersion() - if err != nil { - return err - } - - if printVersion() { - return base.ErrCleanExit - } - return nil -} - -func (i *Info) Stop(m *mgr.Manager) error { - return nil -} - -// printVersion prints the version, if requested, and returns if it did so. 
-func printVersion() (printed bool) { - if showVersion { - fmt.Println(info.FullVersion()) - return true - } - return false -} - -var shimLoaded atomic.Bool - -func New(instance instance) (*Info, error) { - if !shimLoaded.CompareAndSwap(false, true) { - return nil, errors.New("only one instance allowed") - } - module := &Info{ - instance: instance, - } - - return module, nil -} - -type instance interface{} diff --git a/base/rng/test/main.go b/base/rng/test/main.go index 896d86ae7..0778a1376 100644 --- a/base/rng/test/main.go +++ b/base/rng/test/main.go @@ -17,7 +17,6 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" - "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/mgr" ) @@ -57,8 +56,7 @@ func main() { func prep() error { if len(os.Args) < 3 { - fmt.Printf("usage: ./%s {fortuna|tickfeeder} [output size in MB]", os.Args[0]) - return base.ErrCleanExit + return fmt.Errorf("usage: ./%s {fortuna|tickfeeder} [output size in MB]", os.Args[0]) } switch os.Args[1] { diff --git a/go.mod b/go.mod index ac763a24e..4001b4c28 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,10 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-version v1.7.0 github.com/jackc/puddle/v2 v2.2.1 + github.com/lmittmann/tint v1.0.5 github.com/mat/besticon v3.12.0+incompatible + github.com/mattn/go-colorable v0.1.13 + github.com/mattn/go-isatty v0.0.20 github.com/miekg/dns v1.1.61 github.com/mitchellh/copystructure v1.2.0 github.com/mitchellh/go-server-timing v1.0.1 @@ -92,7 +95,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect diff --git a/go.sum b/go.sum index 144bd2981..8e308fd8e 100644 --- a/go.sum +++ b/go.sum @@ -175,12 +175,17 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lmittmann/tint v1.0.5 h1:NQclAutOfYsqs2F1Lenue6OoWCajs5wJcP3DfWVpePw= +github.com/lmittmann/tint v1.0.5/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/magiconair/properties v1.7.4-0.20170902060319-8d7837e64d3c/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mat/besticon v3.12.0+incompatible h1:1KTD6wisfjfnX+fk9Kx/6VEZL+MAW1LhCkL9Q47H9Bg= github.com/mat/besticon v3.12.0+incompatible/go.mod h1:mA1auQYHt6CW5e7L9HJLmqVQC8SzNk2gVwouO0AbiEU= github.com/mattn/go-colorable v0.0.10-0.20170816031813-ad5389df28cd/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.2/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod 
h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo= @@ -420,6 +425,7 @@ golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/spn/cabin/identity_test.go b/spn/cabin/identity_test.go index e0e9ea4f2..cfad082cd 100644 --- a/spn/cabin/identity_test.go +++ b/spn/cabin/identity_test.go @@ -19,7 +19,7 @@ func TestIdentity(t *testing.T) { } // Create new identity. - identityTestKey := "core:spn/public/identity" + identityTestKey := "core:spn/public/identity-test" id, err := CreateIdentity(module.m.Ctx(), conf.MainMapName) if err != nil { t.Fatal(err) @@ -111,19 +111,37 @@ func TestIdentity(t *testing.T) { t.Fatal(err) } + // Ensure the Measurements reset the values. + measurements := id.Hub.GetMeasurements() + measurements.SetLatency(0) + measurements.SetCapacity(0) + measurements.SetCalculatedCost(hub.MaxCalculatedCost) + // Save to and load from database. err = id.Save() if err != nil { t.Fatal(err) } - id2, changed, err := LoadIdentity(identityTestKey) + id2, _, err := LoadIdentity(identityTestKey) if err != nil { t.Fatal(err) } - if changed { - t.Error("unexpected change") - } - // Check if they match - assert.Equal(t, id, id2, "identities should be equal") + // Reset everything that should not be compared. + id.infoExportCache = nil + id2.infoExportCache = nil + id.statusExportCache = nil + id2.statusExportCache = nil + id.ExchKeys = nil + id2.ExchKeys = nil + id.Hub.Status = nil + id2.Hub.Status = nil + id.Hub.PublicKey = nil + id2.Hub.PublicKey = nil + + // Check important aspects of the identities. 
+ assert.Equal(t, id.ID, id2.ID, "identity IDs must be equal") + assert.Equal(t, id.Map, id2.Map, "identity Maps should be equal") + assert.Equal(t, id.Hub, id2.Hub, "identity Hubs should be equal") + assert.Equal(t, id.Signet, id2.Signet, "identity Signets should be equal") } diff --git a/spn/docks/terminal_expansion_test.go b/spn/docks/terminal_expansion_test.go index 7dfb38b81..656f5fdcc 100644 --- a/spn/docks/terminal_expansion_test.go +++ b/spn/docks/terminal_expansion_test.go @@ -1,6 +1,7 @@ package docks import ( + "context" "fmt" "os" "runtime/pprof" @@ -93,6 +94,8 @@ func testExpansion( //nolint:maintidx,thelper var crane1, crane2to1, crane2to3, crane3to2, crane3to4, crane4 *Crane var craneWg sync.WaitGroup + started := time.Now() + craneCtx, cancelCraneCtx := context.WithCancel(context.Background()) craneWg.Add(6) go func() { @@ -102,7 +105,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane1: %s", testID, err)) } crane1.ID = "c1" - err = crane1.Start(module.mgr.Ctx()) + err = crane1.Start(craneCtx) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane1: %s", testID, err)) } @@ -116,7 +119,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane2to1: %s", testID, err)) } crane2to1.ID = "c2to1" - err = crane2to1.Start(module.mgr.Ctx()) + err = crane2to1.Start(craneCtx) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane2to1: %s", testID, err)) } @@ -130,7 +133,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane2to3: %s", testID, err)) } crane2to3.ID = "c2to3" - err = crane2to3.Start(module.mgr.Ctx()) + err = crane2to3.Start(craneCtx) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane2to3: %s", testID, err)) } @@ -144,7 +147,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane3to2: %s", testID, err)) } crane3to2.ID = "c3to2" - err = crane3to2.Start(module.mgr.Ctx()) + err = crane3to2.Start(craneCtx) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane3to2: %s", testID, err)) } @@ -158,7 +161,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane3to4: %s", testID, err)) } crane3to4.ID = "c3to4" - err = crane3to4.Start(module.mgr.Ctx()) + err = crane3to4.Start(craneCtx) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane3to4: %s", testID, err)) } @@ -172,7 +175,7 @@ func testExpansion( //nolint:maintidx,thelper panic(fmt.Sprintf("expansion test %s could not create crane4: %s", testID, err)) } crane4.ID = "c4" - err = crane4.Start(module.mgr.Ctx()) + err = crane4.Start(craneCtx) if err != nil { panic(fmt.Sprintf("expansion test %s could not start crane4: %s", testID, err)) } @@ -288,13 +291,18 @@ func testExpansion( //nolint:maintidx,thelper op1.Wait() } - // Wait for completion. + // Wait for double the time, so that the counters can complete in both directions. + time.Sleep(time.Since(started)) + + // Signal completion. close(finished) // Wait a little so that all errors can be propagated, so we can truly see // if we succeeded. time.Sleep(100 * time.Millisecond) + cancelCraneCtx() + // Check errors. 
if op1.Error != nil { t.Fatalf("crane test %s counter op1 failed: %s", testID, op1.Error) From 33f60303b5041ea644f34893f63beb0b2c8866c6 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Aug 2024 16:53:45 +0200 Subject: [PATCH 49/56] Fix observation hub --- cmds/observation-hub/main.go | 8 +++++ cmds/observation-hub/observe.go | 41 ++++++++++++++++-------- service/mgr/group.go | 57 ++++++++++++++++++--------------- spn/instance.go | 46 ++++++++++++++++++++++---- spn/navigator/pin_export.go | 23 +++++++++++++ 5 files changed, 130 insertions(+), 45 deletions(-) diff --git a/cmds/observation-hub/main.go b/cmds/observation-hub/main.go index b7a8ed712..ecc0cd55e 100644 --- a/cmds/observation-hub/main.go +++ b/cmds/observation-hub/main.go @@ -65,6 +65,14 @@ func main() { os.Exit(2) } + // Add additional modules. + observer, err := New(instance) + if err != nil { + fmt.Printf("error creating an instance: create observer module: %s\n", err) + os.Exit(2) + } + instance.AddModule(observer) + // Execute command line operation, if requested or available. switch { case !execCmdLine: diff --git a/cmds/observation-hub/observe.go b/cmds/observation-hub/observe.go index 8835858e8..354de200f 100644 --- a/cmds/observation-hub/observe.go +++ b/cmds/observation-hub/observe.go @@ -51,7 +51,8 @@ var ( errNoChanges = errors.New("no changes") reportingDelayFlag string - reportingDelay = 10 * time.Minute + reportingDelay = 5 * time.Minute + reportingMaxDelay = reportingDelay * 3 ) func init() { @@ -67,6 +68,7 @@ func prepObserver() error { } reportingDelay = duration } + reportingMaxDelay = reportingDelay * 3 return nil } @@ -81,8 +83,9 @@ type observedPin struct { previous *navigator.PinExport latest *navigator.PinExport - lastUpdate time.Time - lastUpdateReported bool + firstUpdate time.Time + lastUpdate time.Time + updateReported bool } type observedChange struct { @@ -149,19 +152,20 @@ func observerWorker(ctx *mgr.WorkerCtx) error { // Put all current pins in a map. observedPins := make(map[string]*observedPin) -query: +initialQuery: for { select { case r := <-q.Next: // Check if we are done. if r == nil { - break query + break initialQuery } // Add all pins to seen pins. if pin, ok := r.(*navigator.PinExport); ok { observedPins[pin.ID] = &observedPin{ - previous: pin, - latest: pin, + previous: pin, + latest: pin, + updateReported: true, } } else { log.Warningf("observer: received invalid pin export: %s", r) @@ -216,13 +220,18 @@ query: if ok { // Update previously observed Hub. existingObservedPin.latest = pin + if existingObservedPin.updateReported { + existingObservedPin.firstUpdate = time.Now() + } existingObservedPin.lastUpdate = time.Now() - existingObservedPin.lastUpdateReported = false + existingObservedPin.updateReported = false } else { // Add new Hub. observedPins[pin.ID] = &observedPin{ - latest: pin, - lastUpdate: time.Now(), + latest: pin, + firstUpdate: time.Now(), + lastUpdate: time.Now(), + updateReported: false, } } } else { @@ -242,15 +251,19 @@ query: } switch { - case observedPin.lastUpdateReported: + case observedPin.updateReported: // Change already reported. - case time.Since(observedPin.lastUpdate) < reportingDelay: + case time.Since(observedPin.lastUpdate) < reportingDelay && + time.Since(observedPin.firstUpdate) < reportingMaxDelay: // Only report changes if older than the configured delay. + // Up to a maximum delay. default: // Format and report. 
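The reporting condition above is a debounce with an upper bound: a changed pin is reported once it has been quiet for reportingDelay, but it is never held back longer than reportingMaxDelay after the first unreported change. The same rule as a tiny standalone helper (sketch only):

	// shouldReport returns true once a change has been quiet for delay, or once
	// maxDelay has passed since the first unreported change.
	func shouldReport(firstUpdate, lastUpdate time.Time, delay, maxDelay time.Duration) bool {
		return time.Since(lastUpdate) >= delay || time.Since(firstUpdate) >= maxDelay
	}
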
 			title, changes, err := formatPinChanges(observedPin.previous, observedPin.latest)
 			if err != nil {
-				if !errors.Is(err, errNoChanges) {
+				if errors.Is(err, errNoChanges) {
+					log.Debugf("observer: no reportable changes found for %s", observedPin.latest.HumanName())
+				} else {
 					log.Warningf("observer: failed to format pin changes: %s", err)
 				}
 			} else {
@@ -266,7 +279,7 @@ query:
 
 			// Update observed pin.
 			observedPin.previous = observedPin.latest
-			observedPin.lastUpdateReported = true
+			observedPin.updateReported = true
 		}
 	}
 }
diff --git a/service/mgr/group.go b/service/mgr/group.go
index 4a97873bd..035bcf05e 100644
--- a/service/mgr/group.go
+++ b/service/mgr/group.go
@@ -75,36 +75,43 @@ func NewGroup(modules ...Module) *Group {
 
 	// Initialize groups modules.
 	for _, m := range modules {
-		mgr := m.Manager()
-
-		// Check module.
-		switch {
-		case m == nil:
-			// Skip nil values to allow for cleaner code.
-			continue
-		case reflect.ValueOf(m).IsNil():
-			// If nil values are given via a struct, they are will be interfaces to a
-			// nil type. Ignore these too.
-			continue
-		case mgr == nil:
-			// Ignore modules that do not return a manager.
-			continue
-		case mgr.Name() == "":
-			// Force name if none is set.
-			// TODO: Unsafe if module is already logging, etc.
-			mgr.setName(makeModuleName(m))
-		}
-
-		// Add module to group.
-		g.modules = append(g.modules, &groupModule{
-			module: m,
-			mgr:    mgr,
-		})
+		g.Add(m)
 	}
 
 	return g
}
 
+// Add validates the given module and adds it to the group, if all requirements are met.
+// Not safe for concurrent use with any other method.
+// All modules must be added before anything else is done with the group.
+func (g *Group) Add(m Module) {
+	mgr := m.Manager()
+
+	// Check module.
+	switch {
+	case m == nil:
+		// Skip nil values to allow for cleaner code.
+		return
+	case reflect.ValueOf(m).IsNil():
+		// If nil values are given via a struct, they will be interfaces to a
+		// nil type. Ignore these too.
+		return
+	case mgr == nil:
+		// Ignore modules that do not return a manager.
+		return
+	case mgr.Name() == "":
+		// Force name if none is set.
+		// TODO: Unsafe if module is already logging, etc.
+		mgr.setName(makeModuleName(m))
+	}
+
+	// Add module to group.
+	g.modules = append(g.modules, &groupModule{
+		module: m,
+		mgr:    mgr,
+	})
+}
+
 // Start starts all modules in the group in the defined order.
 // If a module fails to start, itself and all previous modules
 // will be stopped in the reverse order.
diff --git a/spn/instance.go b/spn/instance.go index 96866b52b..842ec7b5c 100644 --- a/spn/instance.go +++ b/spn/instance.go @@ -12,6 +12,7 @@ import ( "github.com/safing/portmaster/base/metrics" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/base/rng" + "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/core" "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/intel/filterlists" @@ -44,6 +45,7 @@ type Instance struct { config *config.Config api *api.API metrics *metrics.Metrics + runtime *runtime.Runtime rng *rng.Rng core *core.Core @@ -95,6 +97,10 @@ func New() (*Instance, error) { if err != nil { return instance, fmt.Errorf("create metrics module: %w", err) } + instance.runtime, err = runtime.New(instance) + if err != nil { + return instance, fmt.Errorf("create runtime module: %w", err) + } instance.rng, err = rng.New(instance) if err != nil { return instance, fmt.Errorf("create rng module: %w", err) @@ -171,6 +177,7 @@ func New() (*Instance, error) { instance.config, instance.api, instance.metrics, + instance.runtime, instance.rng, instance.core, @@ -193,6 +200,12 @@ func New() (*Instance, error) { return instance, nil } +// AddModule validates the given module and adds it to the service group, if all requirements are met. +// All modules must be added before anything is done with the instance. +func (i *Instance) AddModule(m mgr.Module) { + i.serviceGroup.Add(m) +} + // SleepyModule is an interface for modules that can enter some sort of sleep mode. type SleepyModule interface { SetSleep(enabled bool) @@ -227,6 +240,11 @@ func (i *Instance) Metrics() *metrics.Metrics { return i.metrics } +// Runtime returns the runtime module. +func (i *Instance) Runtime() *runtime.Runtime { + return i.runtime +} + // Rng returns the rng module. func (i *Instance) Rng() *rng.Rng { return i.rng @@ -359,8 +377,29 @@ func (i *Instance) Stop() error { return i.serviceGroup.Stop() } +// RestartExitCode will instruct portmaster-start to restart the process immediately, potentially with a new version. +const RestartExitCode = 23 + // Shutdown asynchronously stops the instance. -func (i *Instance) Shutdown(exitCode int) { +func (i *Instance) Restart() { + // Send a restart event, give it 10ms extra to propagate. + i.core.EventRestart.Submit(struct{}{}) + time.Sleep(10 * time.Millisecond) + + i.shutdown(RestartExitCode) +} + +// Shutdown asynchronously stops the instance. +func (i *Instance) Shutdown() { + // Send a shutdown event, give it 10ms extra to propagate. + i.core.EventShutdown.Submit(struct{}{}) + time.Sleep(10 * time.Millisecond) + + i.shutdown(0) +} + +func (i *Instance) shutdown(exitCode int) { + // Set given exit code. i.exitCode.Store(int32(exitCode)) m := mgr.New("instance") @@ -386,11 +425,6 @@ func (i *Instance) Stopped() <-chan struct{} { return i.ctx.Done() } -// SetExitCode sets the exit code on the instance. -func (i *Instance) SetExitCode(exitCode int) { - i.exitCode.Store(int32(exitCode)) -} - // ExitCode returns the set exit code of the instance. func (i *Instance) ExitCode() int { return int(i.exitCode.Load()) diff --git a/spn/navigator/pin_export.go b/spn/navigator/pin_export.go index 422a074f9..7e03e2f7d 100644 --- a/spn/navigator/pin_export.go +++ b/spn/navigator/pin_export.go @@ -1,6 +1,7 @@ package navigator import ( + "fmt" "sync" "time" @@ -96,3 +97,25 @@ func (pin *Pin) Export() *PinExport { return export } + +// HumanName returns a human-readable version of a Hub's name. 
+// This name will likely consist of two parts: the given name and the ending of the ID to make it unique.
+func (h *PinExport) HumanName() string {
+	if len(h.ID) < 8 {
+		return fmt.Sprintf("<Hub %s>", h.ID)
+	}
+
+	shortenedID := h.ID[len(h.ID)-8:len(h.ID)-4] +
+		"-" +
+		h.ID[len(h.ID)-4:]
+
+	// Be more careful, as the Hub name is user input.
+	switch {
+	case h.Info.Name == "":
+		return fmt.Sprintf("<Hub %s>", shortenedID)
+	case len(h.Info.Name) > 16:
+		return fmt.Sprintf("<Hub %s %s>", h.Info.Name[:16], shortenedID)
+	default:
+		return fmt.Sprintf("<Hub %s %s>", h.Info.Name, shortenedID)
+	}
+}

From b55c98615606ce4a2301e86cab433df1162d45c6 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Fri, 2 Aug 2024 16:54:04 +0200
Subject: [PATCH 50/56] Improve shutdown and restart handling

---
 service/core/api.go                         |  2 +-
 service/core/core.go                        |  2 +-
 service/instance.go                         | 28 ++++++++++++++++-----
 service/intel/geoip/init_test.go            |  9 ++++++-
 service/netenv/init_test.go                 |  9 ++++++-
 service/profile/endpoints/endpoints_test.go |  9 ++++++-
 service/resolver/main_test.go               |  9 ++++++-
 service/updates/module.go                   |  3 ++-
 service/updates/restart.go                  |  9 ++-----
 spn/captain/module.go                       |  1 -
 spn/hub/hub_test.go                         |  9 ++++++-
 spn/navigator/module_test.go                |  9 ++++++-
 12 files changed, 76 insertions(+), 23 deletions(-)

diff --git a/service/core/api.go b/service/core/api.go
index 5893c7066..1684a4eb7 100644
--- a/service/core/api.go
+++ b/service/core/api.go
@@ -113,7 +113,7 @@ func registerAPIEndpoints() error {
 
 func shutdown(_ *api.Request) (msg string, err error) {
 	log.Warning("core: user requested shutdown via action")
-	module.instance.Shutdown(0)
+	module.instance.Shutdown()
 	return "shutdown initiated", nil
 }
 
diff --git a/service/core/core.go b/service/core/core.go
index 9d6ab2305..33025565c 100644
--- a/service/core/core.go
+++ b/service/core/core.go
@@ -112,5 +112,5 @@ func New(instance instance) (*Core, error) {
 }
 
 type instance interface {
-	Shutdown(exitCode int)
+	Shutdown()
 }
diff --git a/service/instance.go b/service/instance.go
index f4b7d3b76..fd39a0ce4 100644
--- a/service/instance.go
+++ b/service/instance.go
@@ -577,8 +577,29 @@ func (i *Instance) Stop() error {
 	return i.serviceGroup.Stop()
 }
 
+// RestartExitCode will instruct portmaster-start to restart the process immediately, potentially with a new version.
+const RestartExitCode = 23
+
+// Restart asynchronously restarts the instance.
+func (i *Instance) Restart() {
+	// Send a restart event, give it 10ms extra to propagate.
+	i.core.EventRestart.Submit(struct{}{})
+	time.Sleep(10 * time.Millisecond)
+
+	i.shutdown(RestartExitCode)
+}
+
 // Shutdown asynchronously stops the instance.
-func (i *Instance) Shutdown(exitCode int) {
+func (i *Instance) Shutdown() {
+	// Send a shutdown event, give it 10ms extra to propagate.
+	i.core.EventShutdown.Submit(struct{}{})
+	time.Sleep(10 * time.Millisecond)
+
+	i.shutdown(0)
+}
+
+func (i *Instance) shutdown(exitCode int) {
+	// Set given exit code.
 	i.exitCode.Store(int32(exitCode))
 
 	m := mgr.New("instance")
@@ -604,11 +625,6 @@ func (i *Instance) Stopped() <-chan struct{} {
 	return i.ctx.Done()
 }
 
-// SetExitCode sets the exit code on the instance.
-func (i *Instance) SetExitCode(exitCode int) {
-	i.exitCode.Store(int32(exitCode))
-}
-
 // ExitCode returns the set exit code of the instance.
func (i *Instance) ExitCode() int { return int(i.exitCode.Load()) diff --git a/service/intel/geoip/init_test.go b/service/intel/geoip/init_test.go index 52c540324..b6d722dc4 100644 --- a/service/intel/geoip/init_test.go +++ b/service/intel/geoip/init_test.go @@ -8,6 +8,7 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/updates" ) @@ -32,11 +33,17 @@ func (stub *testInstance) Config() *config.Config { return stub.config } +func (stub *testInstance) Notifications() *notifications.Notifications { + return nil +} + func (stub *testInstance) Ready() bool { return true } -func (stub *testInstance) Shutdown(exitCode int) {} +func (stub *testInstance) Restart() {} + +func (stub *testInstance) Shutdown() {} func (stub *testInstance) SetCmdLineOperation(f func() error) {} diff --git a/service/netenv/init_test.go b/service/netenv/init_test.go index bce6c4925..17ef12403 100644 --- a/service/netenv/init_test.go +++ b/service/netenv/init_test.go @@ -8,6 +8,7 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/updates" ) @@ -32,11 +33,17 @@ func (stub *testInstance) Config() *config.Config { return stub.config } +func (stub *testInstance) Notifications() *notifications.Notifications { + return nil +} + func (stub *testInstance) Ready() bool { return true } -func (stub *testInstance) Shutdown(exitCode int) {} +func (stub *testInstance) Restart() {} + +func (stub *testInstance) Shutdown() {} func (stub *testInstance) SetCmdLineOperation(f func() error) {} diff --git a/service/profile/endpoints/endpoints_test.go b/service/profile/endpoints/endpoints_test.go index f4ed1b562..bbc81f6a6 100644 --- a/service/profile/endpoints/endpoints_test.go +++ b/service/profile/endpoints/endpoints_test.go @@ -13,6 +13,7 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/intel" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/updates" @@ -38,11 +39,17 @@ func (stub *testInstance) Config() *config.Config { return stub.config } +func (stub *testInstance) Notifications() *notifications.Notifications { + return nil +} + func (stub *testInstance) Ready() bool { return true } -func (stub *testInstance) Shutdown(exitCode int) {} +func (stub *testInstance) Restart() {} + +func (stub *testInstance) Shutdown() {} func (stub *testInstance) SetCmdLineOperation(f func() error) {} diff --git a/service/resolver/main_test.go b/service/resolver/main_test.go index 44890ed05..99cb7b05b 100644 --- a/service/resolver/main_test.go +++ b/service/resolver/main_test.go @@ -8,6 +8,7 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" @@ -43,11 +44,17 @@ func (stub *testInstance) NetEnv() *netenv.NetEnv { return stub.netenv } +func (stub *testInstance) Notifications() 
*notifications.Notifications { + return nil +} + func (stub *testInstance) Ready() bool { return true } -func (stub *testInstance) Shutdown(exitCode int) {} +func (stub *testInstance) Restart() {} + +func (stub *testInstance) Shutdown() {} func (stub *testInstance) SetCmdLineOperation(f func() error) {} diff --git a/service/updates/module.go b/service/updates/module.go index a1fdab159..c7cae1a92 100644 --- a/service/updates/module.go +++ b/service/updates/module.go @@ -76,6 +76,7 @@ func (u *Updates) Stop() error { type instance interface { API() *api.API Config() *config.Config - Shutdown(exitCode int) + Restart() + Shutdown() Notifications() *notifications.Notifications } diff --git a/service/updates/restart.go b/service/updates/restart.go index b2aa4e33b..729853ff1 100644 --- a/service/updates/restart.go +++ b/service/updates/restart.go @@ -12,11 +12,6 @@ import ( "github.com/safing/portmaster/service/mgr" ) -const ( - // RestartExitCode will instruct portmaster-start to restart the process immediately, potentially with a new version. - RestartExitCode = 23 -) - var ( // RebootOnRestart defines whether the whole system, not just the service, // should be restarted automatically when triggering a restart internally. @@ -114,9 +109,9 @@ func automaticRestart(w *mgr.WorkerCtx) error { // Set restart exit code. if !rebooting { - module.instance.Shutdown(RestartExitCode) + module.instance.Restart() } else { - module.instance.Shutdown(0) + module.instance.Shutdown() } } diff --git a/spn/captain/module.go b/spn/captain/module.go index 02742b02e..b25040624 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -243,5 +243,4 @@ type instance interface { Config() *config.Config Updates() *updates.Updates SPNGroup() *mgr.ExtendedGroup - Shutdown(exitCode int) } diff --git a/spn/hub/hub_test.go b/spn/hub/hub_test.go index 6f0cd60bb..391a61e7f 100644 --- a/spn/hub/hub_test.go +++ b/spn/hub/hub_test.go @@ -11,6 +11,7 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/updates" ) @@ -35,6 +36,10 @@ func (stub *testInstance) Config() *config.Config { return stub.config } +func (stub *testInstance) Notifications() *notifications.Notifications { + return nil +} + func (stub *testInstance) Base() *base.Base { return stub.base } @@ -43,7 +48,9 @@ func (stub *testInstance) Ready() bool { return true } -func (stub *testInstance) Shutdown(exitCode int) {} +func (stub *testInstance) Restart() {} + +func (stub *testInstance) Shutdown() {} func (stub *testInstance) SetCmdLineOperation(f func() error) {} diff --git a/spn/navigator/module_test.go b/spn/navigator/module_test.go index 31ffdf4af..c0de91948 100644 --- a/spn/navigator/module_test.go +++ b/spn/navigator/module_test.go @@ -9,6 +9,7 @@ import ( "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database/dbmodule" "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/core/base" "github.com/safing/portmaster/service/intel/geoip" "github.com/safing/portmaster/service/updates" @@ -39,11 +40,17 @@ func (stub *testInstance) Base() *base.Base { return stub.base } +func (stub *testInstance) Notifications() *notifications.Notifications { + return nil +} + func (stub *testInstance) Ready() bool { return true } -func (stub 
*testInstance) Shutdown(exitCode int) {} +func (stub *testInstance) Restart() {} + +func (stub *testInstance) Shutdown() {} func (stub *testInstance) SetCmdLineOperation(f func() error) {} From 52ea77e7f6b27f311f20344ac8e639f6999ab060 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 2 Aug 2024 16:54:20 +0200 Subject: [PATCH 51/56] Split up big connection.go source file --- service/network/connection.go | 216 ------------------------ service/network/connection_handler.go | 226 ++++++++++++++++++++++++++ 2 files changed, 226 insertions(+), 216 deletions(-) create mode 100644 service/network/connection_handler.go diff --git a/service/network/connection.go b/service/network/connection.go index e1d968b58..7ea96400d 100644 --- a/service/network/connection.go +++ b/service/network/connection.go @@ -15,7 +15,6 @@ import ( "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/notifications" "github.com/safing/portmaster/service/intel" - "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/netenv" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/portmaster/service/network/packet" @@ -815,221 +814,6 @@ func (conn *Connection) delete() { } } -// SetFirewallHandler sets the firewall handler for this link, and starts a -// worker to handle the packets. -// The caller needs to hold a lock on the connection. -// Cannot be called with "nil" handler. Call StopFirewallHandler() instead. -func (conn *Connection) SetFirewallHandler(handler FirewallHandler) { - if handler == nil { - return - } - - // Initialize packet queue, if needed. - conn.pktQueueLock.Lock() - defer conn.pktQueueLock.Unlock() - if !conn.pktQueueActive { - conn.pktQueue = make(chan packet.Packet, 100) - conn.pktQueueActive = true - } - - // Start packet handler worker when new handler is set. - if conn.firewallHandler == nil { - module.mgr.Go("packet handler", conn.packetHandlerWorker) - } - - // Set new handler. - conn.firewallHandler = handler -} - -// UpdateFirewallHandler sets the firewall handler if it already set and the -// given handler is not nil. -// The caller needs to hold a lock on the connection. -func (conn *Connection) UpdateFirewallHandler(handler FirewallHandler) { - if handler != nil && conn.firewallHandler != nil { - conn.firewallHandler = handler - } -} - -// StopFirewallHandler unsets the firewall handler and stops the handler worker. -// The caller needs to hold a lock on the connection. -func (conn *Connection) StopFirewallHandler() { - conn.pktQueueLock.Lock() - defer conn.pktQueueLock.Unlock() - - // Unset the firewall handler to revert to the default handler. - conn.firewallHandler = nil - - // Signal the packet handler worker that it can stop. - if conn.pktQueueActive { - close(conn.pktQueue) - conn.pktQueueActive = false - } - - // Unset the packet queue so that it can be freed. - conn.pktQueue = nil -} - -// HandlePacket queues packet of Link for handling. -func (conn *Connection) HandlePacket(pkt packet.Packet) { - // Update last seen timestamp. - conn.lastSeen.Store(time.Now().Unix()) - - conn.pktQueueLock.Lock() - defer conn.pktQueueLock.Unlock() - - // execute handler or verdict - if conn.pktQueueActive { - select { - case conn.pktQueue <- pkt: - default: - log.Debugf( - "filter: dropping packet %s, as there is no space in the connection's handling queue", - pkt, - ) - _ = pkt.Drop() - } - } else { - // Run default handler. - defaultFirewallHandler(conn, pkt) - - // Record metrics. 
- packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt) - } -} - -var infoOnlyPacketsActive = abool.New() - -// packetHandlerWorker sequentially handles queued packets. -func (conn *Connection) packetHandlerWorker(ctx *mgr.WorkerCtx) error { - // Copy packet queue, so we can remove the reference from the connection - // when we stop the firewall handler. - var pktQueue chan packet.Packet - func() { - conn.pktQueueLock.Lock() - defer conn.pktQueueLock.Unlock() - pktQueue = conn.pktQueue - }() - - // pktSeq counts the seen packets. - var pktSeq int - - for { - select { - case pkt := <-pktQueue: - if pkt == nil { - return nil - } - pktSeq++ - - // Attempt to optimize packet handling order by handling info-only packets first. - switch { - case pktSeq > 1: - // Order correction is only for first packet. - - case pkt.InfoOnly(): - // Correct order only if first packet is not info-only. - - // We have observed a first packet that is info-only. - // Info-only packets seem to be active and working. - infoOnlyPacketsActive.Set() - - case pkt.ExpectInfo(): - // Packet itself tells us that we should expect an info-only packet. - fallthrough - - case infoOnlyPacketsActive.IsSet() && pkt.IsOutbound(): - // Info-only packets are active and the packet is outbound. - // The probability is high that we will also get an info-only packet for this connection. - // TODO: Do not do this for forwarded packets in the future. - - // DEBUG: - // log.Debugf("filter: waiting for info only packet in order to pull forward: %s", pkt) - select { - case infoPkt := <-pktQueue: - if infoPkt != nil { - // DEBUG: - // log.Debugf("filter: packet #%d [pulled forward] info=%v PID=%d packet: %s", pktSeq, infoPkt.InfoOnly(), infoPkt.Info().PID, pkt) - packetHandlerHandleConn(ctx.Ctx(), conn, infoPkt) - pktSeq++ - } - case <-time.After(1 * time.Millisecond): - } - } - - // DEBUG: - // switch { - // case pkt.Info().Inbound: - // log.Debugf("filter: packet #%d info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) - // case pktSeq == 1 && !pkt.InfoOnly(): - // log.Warningf("filter: packet #%d [should be info only!] info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) - // case pktSeq >= 2 && pkt.InfoOnly(): - // log.Errorf("filter: packet #%d [should not be info only!] info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) - // default: - // log.Debugf("filter: packet #%d info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) - // } - - packetHandlerHandleConn(ctx.Ctx(), conn, pkt) - - case <-ctx.Done(): - return nil - } - } -} - -func packetHandlerHandleConn(ctx context.Context, conn *Connection, pkt packet.Packet) { - conn.Lock() - defer conn.Unlock() - - // Check if we should use the default handler. - // The default handler is only for fully decided - // connections and just applying the verdict. - // There is no logging for these packets. - if conn.firewallHandler == nil { - // Run default handler. - defaultFirewallHandler(conn, pkt) - - // Record metrics. - packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt) - - return - } - - // Create tracing context. - // Add context tracer and set context on packet. - traceCtx, tracer := log.AddTracer(ctx) - if tracer != nil { - // The trace is submitted in `network.Connection.packetHandler()`. - tracer.Tracef("filter: handling packet: %s", pkt) - } - pkt.SetCtx(traceCtx) - - // Handle packet with set handler. - conn.firewallHandler(conn, pkt) - - // Record metrics. 
- packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt) - - // Log result and submit trace, when there are any changes. - if conn.saveWhenFinished { - switch { - case conn.DataIsComplete(): - tracer.Infof("filter: connection %s %s: %s", conn, conn.VerdictVerb(), conn.Reason.Msg) - case conn.Verdict != VerdictUndecided: - tracer.Debugf("filter: connection %s fast-tracked", pkt) - default: - tracer.Debugf("filter: gathered data on connection %s", conn) - } - // Submit trace logs. - tracer.Submit() - } - - // Push changes, if there are any. - if conn.saveWhenFinished { - conn.saveWhenFinished = false - conn.Save() - } -} - // GetActiveInspectors returns the list of active inspectors. func (conn *Connection) GetActiveInspectors() []bool { return conn.activeInspectors diff --git a/service/network/connection_handler.go b/service/network/connection_handler.go new file mode 100644 index 000000000..2d2613260 --- /dev/null +++ b/service/network/connection_handler.go @@ -0,0 +1,226 @@ +package network + +import ( + "context" + "time" + + "github.com/safing/portmaster/base/log" + "github.com/safing/portmaster/service/mgr" + "github.com/safing/portmaster/service/network/packet" + "github.com/tevino/abool" +) + +// SetFirewallHandler sets the firewall handler for this link, and starts a +// worker to handle the packets. +// The caller needs to hold a lock on the connection. +// Cannot be called with "nil" handler. Call StopFirewallHandler() instead. +func (conn *Connection) SetFirewallHandler(handler FirewallHandler) { + if handler == nil { + return + } + + // Initialize packet queue, if needed. + conn.pktQueueLock.Lock() + defer conn.pktQueueLock.Unlock() + if !conn.pktQueueActive { + conn.pktQueue = make(chan packet.Packet, 100) + conn.pktQueueActive = true + } + + // Start packet handler worker when new handler is set. + if conn.firewallHandler == nil { + module.mgr.Go("packet handler", conn.packetHandlerWorker) + } + + // Set new handler. + conn.firewallHandler = handler +} + +// UpdateFirewallHandler sets the firewall handler if it already set and the +// given handler is not nil. +// The caller needs to hold a lock on the connection. +func (conn *Connection) UpdateFirewallHandler(handler FirewallHandler) { + if handler != nil && conn.firewallHandler != nil { + conn.firewallHandler = handler + } +} + +// StopFirewallHandler unsets the firewall handler and stops the handler worker. +// The caller needs to hold a lock on the connection. +func (conn *Connection) StopFirewallHandler() { + conn.pktQueueLock.Lock() + defer conn.pktQueueLock.Unlock() + + // Unset the firewall handler to revert to the default handler. + conn.firewallHandler = nil + + // Signal the packet handler worker that it can stop. + if conn.pktQueueActive { + close(conn.pktQueue) + conn.pktQueueActive = false + } + + // Unset the packet queue so that it can be freed. + conn.pktQueue = nil +} + +// HandlePacket queues packet of Link for handling. +func (conn *Connection) HandlePacket(pkt packet.Packet) { + // Update last seen timestamp. + conn.lastSeen.Store(time.Now().Unix()) + + conn.pktQueueLock.Lock() + defer conn.pktQueueLock.Unlock() + + // execute handler or verdict + if conn.pktQueueActive { + select { + case conn.pktQueue <- pkt: + default: + log.Debugf( + "filter: dropping packet %s, as there is no space in the connection's handling queue", + pkt, + ) + _ = pkt.Drop() + } + } else { + // Run default handler. + defaultFirewallHandler(conn, pkt) + + // Record metrics. 
+ packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt) + } +} + +var infoOnlyPacketsActive = abool.New() + +// packetHandlerWorker sequentially handles queued packets. +func (conn *Connection) packetHandlerWorker(ctx *mgr.WorkerCtx) error { + // Copy packet queue, so we can remove the reference from the connection + // when we stop the firewall handler. + var pktQueue chan packet.Packet + func() { + conn.pktQueueLock.Lock() + defer conn.pktQueueLock.Unlock() + pktQueue = conn.pktQueue + }() + + // pktSeq counts the seen packets. + var pktSeq int + + for { + select { + case pkt := <-pktQueue: + if pkt == nil { + return nil + } + pktSeq++ + + // Attempt to optimize packet handling order by handling info-only packets first. + switch { + case pktSeq > 1: + // Order correction is only for first packet. + + case pkt.InfoOnly(): + // Correct order only if first packet is not info-only. + + // We have observed a first packet that is info-only. + // Info-only packets seem to be active and working. + infoOnlyPacketsActive.Set() + + case pkt.ExpectInfo(): + // Packet itself tells us that we should expect an info-only packet. + fallthrough + + case infoOnlyPacketsActive.IsSet() && pkt.IsOutbound(): + // Info-only packets are active and the packet is outbound. + // The probability is high that we will also get an info-only packet for this connection. + // TODO: Do not do this for forwarded packets in the future. + + // DEBUG: + // log.Debugf("filter: waiting for info only packet in order to pull forward: %s", pkt) + select { + case infoPkt := <-pktQueue: + if infoPkt != nil { + // DEBUG: + // log.Debugf("filter: packet #%d [pulled forward] info=%v PID=%d packet: %s", pktSeq, infoPkt.InfoOnly(), infoPkt.Info().PID, pkt) + packetHandlerHandleConn(ctx.Ctx(), conn, infoPkt) + pktSeq++ + } + case <-time.After(1 * time.Millisecond): + } + } + + // DEBUG: + // switch { + // case pkt.Info().Inbound: + // log.Debugf("filter: packet #%d info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) + // case pktSeq == 1 && !pkt.InfoOnly(): + // log.Warningf("filter: packet #%d [should be info only!] info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) + // case pktSeq >= 2 && pkt.InfoOnly(): + // log.Errorf("filter: packet #%d [should not be info only!] info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) + // default: + // log.Debugf("filter: packet #%d info=%v PID=%d packet: %s", pktSeq, pkt.InfoOnly(), pkt.Info().PID, pkt) + // } + + packetHandlerHandleConn(ctx.Ctx(), conn, pkt) + + case <-ctx.Done(): + return nil + } + } +} + +func packetHandlerHandleConn(ctx context.Context, conn *Connection, pkt packet.Packet) { + conn.Lock() + defer conn.Unlock() + + // Check if we should use the default handler. + // The default handler is only for fully decided + // connections and just applying the verdict. + // There is no logging for these packets. + if conn.firewallHandler == nil { + // Run default handler. + defaultFirewallHandler(conn, pkt) + + // Record metrics. + packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt) + + return + } + + // Create tracing context. + // Add context tracer and set context on packet. + traceCtx, tracer := log.AddTracer(ctx) + if tracer != nil { + // The trace is submitted in `network.Connection.packetHandler()`. + tracer.Tracef("filter: handling packet: %s", pkt) + } + pkt.SetCtx(traceCtx) + + // Handle packet with set handler. + conn.firewallHandler(conn, pkt) + + // Record metrics. 
+ packetHandlingHistogram.UpdateDuration(pkt.Info().SeenAt) + + // Log result and submit trace, when there are any changes. + if conn.saveWhenFinished { + switch { + case conn.DataIsComplete(): + tracer.Infof("filter: connection %s %s: %s", conn, conn.VerdictVerb(), conn.Reason.Msg) + case conn.Verdict != VerdictUndecided: + tracer.Debugf("filter: connection %s fast-tracked", pkt) + default: + tracer.Debugf("filter: gathered data on connection %s", conn) + } + // Submit trace logs. + tracer.Submit() + } + + // Push changes, if there are any. + if conn.saveWhenFinished { + conn.saveWhenFinished = false + conn.Save() + } +} From 0911bd2b846809b80b5238179e5cb5e82f901192 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 9 Aug 2024 13:24:38 +0200 Subject: [PATCH 52/56] Move varint and dsd packages to structures repo --- base/api/client/message.go | 2 +- base/api/database.go | 4 +- base/api/endpoints.go | 2 +- base/config/option.go | 2 +- base/container/container.go | 2 +- base/database/migration/migration.go | 2 +- base/database/query/query_test.go | 2 +- base/database/record/base.go | 2 +- base/database/record/meta-bench_test.go | 4 +- base/database/record/wrapper.go | 4 +- base/database/record/wrapper_test.go | 2 +- base/formats/dsd/compression.go | 103 --- base/formats/dsd/dsd.go | 160 ----- base/formats/dsd/dsd_test.go | 327 --------- base/formats/dsd/format.go | 73 -- base/formats/dsd/gencode_test.go | 824 ----------------------- base/formats/dsd/http.go | 178 ----- base/formats/dsd/http_test.go | 45 -- base/formats/dsd/interfaces.go | 9 - base/formats/dsd/tests.gencode | 23 - base/formats/varint/helpers.go | 48 -- base/formats/varint/varint.go | 97 --- base/formats/varint/varint_test.go | 141 ---- cmds/notifier/notify.go | 2 +- cmds/notifier/spn.go | 2 +- cmds/notifier/subsystems.go | 2 +- cmds/portmaster-start/logs.go | 2 +- service/intel/filterlists/decoder.go | 2 +- service/intel/filterlists/index.go | 2 +- service/netquery/manager.go | 2 +- service/netquery/runtime_query_runner.go | 2 +- service/profile/api.go | 2 +- service/sync/setting_single.go | 2 +- service/sync/util.go | 2 +- spn/access/client.go | 2 +- spn/access/storage.go | 2 +- spn/access/token/pblind.go | 2 +- spn/access/token/request_test.go | 2 +- spn/access/token/scramble.go | 2 +- spn/cabin/verification.go | 2 +- spn/captain/bootstrap.go | 2 +- spn/captain/op_gossip.go | 2 +- spn/captain/op_gossip_query.go | 2 +- spn/crew/op_connect.go | 2 +- spn/crew/op_ping.go | 2 +- spn/docks/bandwidth_test.go | 2 +- spn/docks/crane.go | 2 +- spn/docks/crane_init.go | 4 +- spn/docks/crane_verify.go | 2 +- spn/docks/op_capacity.go | 2 +- spn/docks/op_latency.go | 2 +- spn/docks/op_sync_state.go | 2 +- spn/docks/op_whoami.go | 2 +- spn/hub/update.go | 2 +- spn/hub/update_test.go | 2 +- spn/terminal/control_flow.go | 2 +- spn/terminal/errors.go | 2 +- spn/terminal/init.go | 4 +- spn/terminal/msgtypes.go | 2 +- spn/terminal/operation_counter.go | 4 +- 60 files changed, 54 insertions(+), 2082 deletions(-) delete mode 100644 base/formats/dsd/compression.go delete mode 100644 base/formats/dsd/dsd.go delete mode 100644 base/formats/dsd/dsd_test.go delete mode 100644 base/formats/dsd/format.go delete mode 100644 base/formats/dsd/gencode_test.go delete mode 100644 base/formats/dsd/http.go delete mode 100644 base/formats/dsd/http_test.go delete mode 100644 base/formats/dsd/interfaces.go delete mode 100644 base/formats/dsd/tests.gencode delete mode 100644 base/formats/varint/helpers.go delete mode 100644 base/formats/varint/varint.go delete mode 
100644 base/formats/varint/varint_test.go diff --git a/base/api/client/message.go b/base/api/client/message.go index 85754e230..fc72ee2bc 100644 --- a/base/api/client/message.go +++ b/base/api/client/message.go @@ -6,8 +6,8 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) // ErrMalformedMessage is returned when a malformed message was encountered. diff --git a/base/api/database.go b/base/api/database.go index 295559dcd..55edac08d 100644 --- a/base/api/database.go +++ b/base/api/database.go @@ -16,11 +16,11 @@ import ( "github.com/safing/portmaster/base/database/iterator" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" + "github.com/safing/structures/varint" ) const ( diff --git a/base/api/endpoints.go b/base/api/endpoints.go index 7ee14a4c0..6c8be79ec 100644 --- a/base/api/endpoints.go +++ b/base/api/endpoints.go @@ -14,8 +14,8 @@ import ( "github.com/gorilla/mux" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" + "github.com/safing/structures/dsd" ) // Endpoint describes an API Endpoint. diff --git a/base/config/option.go b/base/config/option.go index 22b1c2022..a975dfac4 100644 --- a/base/config/option.go +++ b/base/config/option.go @@ -11,7 +11,7 @@ import ( "github.com/tidwall/sjson" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/dsd" ) // OptionType defines the value type of an option. diff --git a/base/container/container.go b/base/container/container.go index 775fc2053..d30e9615f 100644 --- a/base/container/container.go +++ b/base/container/container.go @@ -4,7 +4,7 @@ import ( "errors" "io" - "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/structures/varint" ) // Container is []byte sclie on steroids, allowing for quick data appending, prepending and fetching. 
diff --git a/base/database/migration/migration.go b/base/database/migration/migration.go index 73b400839..e998c349a 100644 --- a/base/database/migration/migration.go +++ b/base/database/migration/migration.go @@ -12,8 +12,8 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" + "github.com/safing/structures/dsd" ) // MigrateFunc is called when a migration should be applied to the diff --git a/base/database/query/query_test.go b/base/database/query/query_test.go index 402dc6151..b174fe296 100644 --- a/base/database/query/query_test.go +++ b/base/database/query/query_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/dsd" ) // copied from https://github.com/tidwall/gjson/blob/master/gjson_test.go diff --git a/base/database/record/base.go b/base/database/record/base.go index 347255a1f..f26c165a5 100644 --- a/base/database/record/base.go +++ b/base/database/record/base.go @@ -4,9 +4,9 @@ import ( "errors" "github.com/safing/portmaster/base/database/accessor" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) // TODO(ppacher): diff --git a/base/database/record/meta-bench_test.go b/base/database/record/meta-bench_test.go index bf1824db2..bfcf05173 100644 --- a/base/database/record/meta-bench_test.go +++ b/base/database/record/meta-bench_test.go @@ -21,9 +21,9 @@ import ( "testing" "time" - "github.com/safing/portmaster/base/formats/dsd" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" + "github.com/safing/structures/varint" ) var testMeta = &Meta{ diff --git a/base/database/record/wrapper.go b/base/database/record/wrapper.go index 5204ffa1a..0f9a9e40c 100644 --- a/base/database/record/wrapper.go +++ b/base/database/record/wrapper.go @@ -6,9 +6,9 @@ import ( "sync" "github.com/safing/portmaster/base/database/accessor" - "github.com/safing/portmaster/base/formats/dsd" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" + "github.com/safing/structures/varint" ) // Wrapper wraps raw data and implements the Record interface. diff --git a/base/database/record/wrapper_test.go b/base/database/record/wrapper_test.go index 5db3b01d3..2e923bd7c 100644 --- a/base/database/record/wrapper_test.go +++ b/base/database/record/wrapper_test.go @@ -4,7 +4,7 @@ import ( "bytes" "testing" - "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/dsd" ) func TestWrapper(t *testing.T) { diff --git a/base/formats/dsd/compression.go b/base/formats/dsd/compression.go deleted file mode 100644 index d1baf2822..000000000 --- a/base/formats/dsd/compression.go +++ /dev/null @@ -1,103 +0,0 @@ -package dsd - -import ( - "bytes" - "compress/gzip" - "errors" - - "github.com/safing/portmaster/base/formats/varint" -) - -// DumpAndCompress stores the interface as a dsd formatted data structure and compresses the resulting data. -func DumpAndCompress(t interface{}, format uint8, compression uint8) ([]byte, error) { - // Check if compression format is valid. 
- compression, ok := ValidateCompressionFormat(compression) - if !ok { - return nil, ErrIncompatibleFormat - } - - // Dump the given data with the given format. - data, err := Dump(t, format) - if err != nil { - return nil, err - } - - // prepare writer - packetFormat := varint.Pack8(compression) - buf := bytes.NewBuffer(nil) - buf.Write(packetFormat) - - // compress - switch compression { - case GZIP: - // create gzip writer - gzipWriter, err := gzip.NewWriterLevel(buf, gzip.BestCompression) - if err != nil { - return nil, err - } - - // write data - n, err := gzipWriter.Write(data) - if err != nil { - return nil, err - } - if n != len(data) { - return nil, errors.New("failed to fully write to gzip compressor") - } - - // flush and write gzip footer - err = gzipWriter.Close() - if err != nil { - return nil, err - } - default: - return nil, ErrIncompatibleFormat - } - - return buf.Bytes(), nil -} - -// DecompressAndLoad decompresses the data using the specified compression format and then loads the resulting data blob into the interface. -func DecompressAndLoad(data []byte, compression uint8, t interface{}) (format uint8, err error) { - // Check if compression format is valid. - _, ok := ValidateCompressionFormat(compression) - if !ok { - return 0, ErrIncompatibleFormat - } - - // prepare reader - buf := bytes.NewBuffer(nil) - - // decompress - switch compression { - case GZIP: - // create gzip reader - gzipReader, err := gzip.NewReader(bytes.NewBuffer(data)) - if err != nil { - return 0, err - } - - // read uncompressed data - _, err = buf.ReadFrom(gzipReader) - if err != nil { - return 0, err - } - - // flush and verify gzip footer - err = gzipReader.Close() - if err != nil { - return 0, err - } - default: - return 0, ErrIncompatibleFormat - } - - // assign decompressed data - data = buf.Bytes() - - format, read, err := loadFormat(data) - if err != nil { - return 0, err - } - return format, LoadAsFormat(data[read:], format, t) -} diff --git a/base/formats/dsd/dsd.go b/base/formats/dsd/dsd.go deleted file mode 100644 index 76b8c4446..000000000 --- a/base/formats/dsd/dsd.go +++ /dev/null @@ -1,160 +0,0 @@ -package dsd - -// dynamic structured data -// check here for some benchmarks: https://github.com/alecthomas/go_serialization_benchmarks - -import ( - "encoding/json" - "errors" - "fmt" - "io" - - "github.com/fxamacker/cbor/v2" - "github.com/ghodss/yaml" - "github.com/vmihailenco/msgpack/v5" - - "github.com/safing/portmaster/base/formats/varint" - "github.com/safing/portmaster/base/utils" -) - -// Load loads an dsd structured data blob into the given interface. -func Load(data []byte, t interface{}) (format uint8, err error) { - format, read, err := loadFormat(data) - if err != nil { - return 0, err - } - - _, ok := ValidateSerializationFormat(format) - if ok { - return format, LoadAsFormat(data[read:], format, t) - } - return DecompressAndLoad(data[read:], format, t) -} - -// LoadAsFormat loads a data blob into the interface using the specified format. 
-func LoadAsFormat(data []byte, format uint8, t interface{}) (err error) { - switch format { - case RAW: - return ErrIsRaw - case JSON: - err = json.Unmarshal(data, t) - if err != nil { - return fmt.Errorf("dsd: failed to unpack json: %w, data: %s", err, utils.SafeFirst16Bytes(data)) - } - return nil - case YAML: - err = yaml.Unmarshal(data, t) - if err != nil { - return fmt.Errorf("dsd: failed to unpack yaml: %w, data: %s", err, utils.SafeFirst16Bytes(data)) - } - return nil - case CBOR: - err = cbor.Unmarshal(data, t) - if err != nil { - return fmt.Errorf("dsd: failed to unpack cbor: %w, data: %s", err, utils.SafeFirst16Bytes(data)) - } - return nil - case MsgPack: - err = msgpack.Unmarshal(data, t) - if err != nil { - return fmt.Errorf("dsd: failed to unpack msgpack: %w, data: %s", err, utils.SafeFirst16Bytes(data)) - } - return nil - case GenCode: - genCodeStruct, ok := t.(GenCodeCompatible) - if !ok { - return errors.New("dsd: gencode is not supported by the given data structure") - } - _, err = genCodeStruct.GenCodeUnmarshal(data) - if err != nil { - return fmt.Errorf("dsd: failed to unpack gencode: %w, data: %s", err, utils.SafeFirst16Bytes(data)) - } - return nil - default: - return ErrIncompatibleFormat - } -} - -func loadFormat(data []byte) (format uint8, read int, err error) { - format, read, err = varint.Unpack8(data) - if err != nil { - return 0, 0, err - } - if len(data) <= read { - return 0, 0, io.ErrUnexpectedEOF - } - - return format, read, nil -} - -// Dump stores the interface as a dsd formatted data structure. -func Dump(t interface{}, format uint8) ([]byte, error) { - return DumpIndent(t, format, "") -} - -// DumpIndent stores the interface as a dsd formatted data structure with indentation, if available. -func DumpIndent(t interface{}, format uint8, indent string) ([]byte, error) { - data, err := dumpWithoutIdentifier(t, format, indent) - if err != nil { - return nil, err - } - - // TODO: Find a better way to do this. - return append(varint.Pack8(format), data...), nil -} - -func dumpWithoutIdentifier(t interface{}, format uint8, indent string) ([]byte, error) { - format, ok := ValidateSerializationFormat(format) - if !ok { - return nil, ErrIncompatibleFormat - } - - var data []byte - var err error - switch format { - case RAW: - var ok bool - data, ok = t.([]byte) - if !ok { - return nil, ErrIncompatibleFormat - } - case JSON: - // TODO: use SetEscapeHTML(false) - if indent != "" { - data, err = json.MarshalIndent(t, "", indent) - } else { - data, err = json.Marshal(t) - } - if err != nil { - return nil, err - } - case YAML: - data, err = yaml.Marshal(t) - if err != nil { - return nil, err - } - case CBOR: - data, err = cbor.Marshal(t) - if err != nil { - return nil, err - } - case MsgPack: - data, err = msgpack.Marshal(t) - if err != nil { - return nil, err - } - case GenCode: - genCodeStruct, ok := t.(GenCodeCompatible) - if !ok { - return nil, errors.New("dsd: gencode is not supported by the given data structure") - } - data, err = genCodeStruct.GenCodeMarshal(nil) - if err != nil { - return nil, fmt.Errorf("dsd: failed to pack gencode struct: %w", err) - } - default: - return nil, ErrIncompatibleFormat - } - - return data, nil -} diff --git a/base/formats/dsd/dsd_test.go b/base/formats/dsd/dsd_test.go deleted file mode 100644 index 479f72711..000000000 --- a/base/formats/dsd/dsd_test.go +++ /dev/null @@ -1,327 +0,0 @@ -//nolint:maligned,gocyclo,gocognit -package dsd - -import ( - "math/big" - "reflect" - "testing" -) - -// SimpleTestStruct is used for testing. 
-type SimpleTestStruct struct { - S string - B byte -} - -type ComplexTestStruct struct { - I int - I8 int8 - I16 int16 - I32 int32 - I64 int64 - UI uint - UI8 uint8 - UI16 uint16 - UI32 uint32 - UI64 uint64 - BI *big.Int - S string - Sp *string - Sa []string - Sap *[]string - B byte - Bp *byte - Ba []byte - Bap *[]byte - M map[string]string - Mp *map[string]string -} - -type GenCodeTestStruct struct { - I8 int8 - I16 int16 - I32 int32 - I64 int64 - UI8 uint8 - UI16 uint16 - UI32 uint32 - UI64 uint64 - S string - Sp *string - Sa []string - Sap *[]string - B byte - Bp *byte - Ba []byte - Bap *[]byte -} - -var ( - simpleSubject = &SimpleTestStruct{ - "a", - 0x01, - } - - bString = "b" - bBytes byte = 0x02 - - complexSubject = &ComplexTestStruct{ - -1, - -2, - -3, - -4, - -5, - 1, - 2, - 3, - 4, - 5, - big.NewInt(6), - "a", - &bString, - []string{"c", "d", "e"}, - &[]string{"f", "g", "h"}, - 0x01, - &bBytes, - []byte{0x03, 0x04, 0x05}, - &[]byte{0x05, 0x06, 0x07}, - map[string]string{ - "a": "b", - "c": "d", - "e": "f", - }, - &map[string]string{ - "g": "h", - "i": "j", - "k": "l", - }, - } - - genCodeSubject = &GenCodeTestStruct{ - -2, - -3, - -4, - -5, - 2, - 3, - 4, - 5, - "a", - &bString, - []string{"c", "d", "e"}, - &[]string{"f", "g", "h"}, - 0x01, - &bBytes, - []byte{0x03, 0x04, 0x05}, - &[]byte{0x05, 0x06, 0x07}, - } -) - -func TestConversion(t *testing.T) { //nolint:maintidx - t.Parallel() - - compressionFormats := []uint8{AUTO, GZIP} - formats := []uint8{JSON, CBOR, MsgPack} - - for _, compression := range compressionFormats { - for _, format := range formats { - - // simple - var b []byte - var err error - if compression != AUTO { - b, err = DumpAndCompress(simpleSubject, format, compression) - } else { - b, err = Dump(simpleSubject, format) - } - if err != nil { - t.Fatalf("Dump error (simple struct): %s", err) - } - - si := &SimpleTestStruct{} - _, err = Load(b, si) - if err != nil { - t.Fatalf("Load error (simple struct): %s", err) - } - - if !reflect.DeepEqual(simpleSubject, si) { - t.Errorf("Load (simple struct): subject does not match loaded object") - t.Errorf("Encoded: %v", string(b)) - t.Errorf("Compared: %v == %v", simpleSubject, si) - } - - // complex - if compression != AUTO { - b, err = DumpAndCompress(complexSubject, format, compression) - } else { - b, err = Dump(complexSubject, format) - } - if err != nil { - t.Fatalf("Dump error (complex struct): %s", err) - } - - co := &ComplexTestStruct{} - _, err = Load(b, co) - if err != nil { - t.Fatalf("Load error (complex struct): %s", err) - } - - if complexSubject.I != co.I { - t.Errorf("Load (complex struct): struct.I is not equal (%v != %v)", complexSubject.I, co.I) - } - if complexSubject.I8 != co.I8 { - t.Errorf("Load (complex struct): struct.I8 is not equal (%v != %v)", complexSubject.I8, co.I8) - } - if complexSubject.I16 != co.I16 { - t.Errorf("Load (complex struct): struct.I16 is not equal (%v != %v)", complexSubject.I16, co.I16) - } - if complexSubject.I32 != co.I32 { - t.Errorf("Load (complex struct): struct.I32 is not equal (%v != %v)", complexSubject.I32, co.I32) - } - if complexSubject.I64 != co.I64 { - t.Errorf("Load (complex struct): struct.I64 is not equal (%v != %v)", complexSubject.I64, co.I64) - } - if complexSubject.UI != co.UI { - t.Errorf("Load (complex struct): struct.UI is not equal (%v != %v)", complexSubject.UI, co.UI) - } - if complexSubject.UI8 != co.UI8 { - t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", complexSubject.UI8, co.UI8) - } - if complexSubject.UI16 != co.UI16 { 
- t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", complexSubject.UI16, co.UI16) - } - if complexSubject.UI32 != co.UI32 { - t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", complexSubject.UI32, co.UI32) - } - if complexSubject.UI64 != co.UI64 { - t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", complexSubject.UI64, co.UI64) - } - if complexSubject.BI.Cmp(co.BI) != 0 { - t.Errorf("Load (complex struct): struct.BI is not equal (%v != %v)", complexSubject.BI, co.BI) - } - if complexSubject.S != co.S { - t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", complexSubject.S, co.S) - } - if !reflect.DeepEqual(complexSubject.Sp, co.Sp) { - t.Errorf("Load (complex struct): struct.Sp is not equal (%v != %v)", complexSubject.Sp, co.Sp) - } - if !reflect.DeepEqual(complexSubject.Sa, co.Sa) { - t.Errorf("Load (complex struct): struct.Sa is not equal (%v != %v)", complexSubject.Sa, co.Sa) - } - if !reflect.DeepEqual(complexSubject.Sap, co.Sap) { - t.Errorf("Load (complex struct): struct.Sap is not equal (%v != %v)", complexSubject.Sap, co.Sap) - } - if complexSubject.B != co.B { - t.Errorf("Load (complex struct): struct.B is not equal (%v != %v)", complexSubject.B, co.B) - } - if !reflect.DeepEqual(complexSubject.Bp, co.Bp) { - t.Errorf("Load (complex struct): struct.Bp is not equal (%v != %v)", complexSubject.Bp, co.Bp) - } - if !reflect.DeepEqual(complexSubject.Ba, co.Ba) { - t.Errorf("Load (complex struct): struct.Ba is not equal (%v != %v)", complexSubject.Ba, co.Ba) - } - if !reflect.DeepEqual(complexSubject.Bap, co.Bap) { - t.Errorf("Load (complex struct): struct.Bap is not equal (%v != %v)", complexSubject.Bap, co.Bap) - } - if !reflect.DeepEqual(complexSubject.M, co.M) { - t.Errorf("Load (complex struct): struct.M is not equal (%v != %v)", complexSubject.M, co.M) - } - if !reflect.DeepEqual(complexSubject.Mp, co.Mp) { - t.Errorf("Load (complex struct): struct.Mp is not equal (%v != %v)", complexSubject.Mp, co.Mp) - } - - } - - // test all formats - simplifiedFormatTesting := []uint8{JSON, CBOR, MsgPack, GenCode} - - for _, format := range simplifiedFormatTesting { - - // simple - var b []byte - var err error - if compression != AUTO { - b, err = DumpAndCompress(simpleSubject, format, compression) - } else { - b, err = Dump(simpleSubject, format) - } - if err != nil { - t.Fatalf("Dump error (simple struct): %s", err) - } - - si := &SimpleTestStruct{} - _, err = Load(b, si) - if err != nil { - t.Fatalf("Load error (simple struct): %s", err) - } - - if !reflect.DeepEqual(simpleSubject, si) { - t.Errorf("Load (simple struct): subject does not match loaded object") - t.Errorf("Encoded: %v", string(b)) - t.Errorf("Compared: %v == %v", simpleSubject, si) - } - - // complex - b, err = DumpAndCompress(genCodeSubject, format, compression) - if err != nil { - t.Fatalf("Dump error (complex struct): %s", err) - } - - co := &GenCodeTestStruct{} - _, err = Load(b, co) - if err != nil { - t.Fatalf("Load error (complex struct): %s", err) - } - - if genCodeSubject.I8 != co.I8 { - t.Errorf("Load (complex struct): struct.I8 is not equal (%v != %v)", genCodeSubject.I8, co.I8) - } - if genCodeSubject.I16 != co.I16 { - t.Errorf("Load (complex struct): struct.I16 is not equal (%v != %v)", genCodeSubject.I16, co.I16) - } - if genCodeSubject.I32 != co.I32 { - t.Errorf("Load (complex struct): struct.I32 is not equal (%v != %v)", genCodeSubject.I32, co.I32) - } - if genCodeSubject.I64 != co.I64 { - t.Errorf("Load (complex struct): struct.I64 
is not equal (%v != %v)", genCodeSubject.I64, co.I64) - } - if genCodeSubject.UI8 != co.UI8 { - t.Errorf("Load (complex struct): struct.UI8 is not equal (%v != %v)", genCodeSubject.UI8, co.UI8) - } - if genCodeSubject.UI16 != co.UI16 { - t.Errorf("Load (complex struct): struct.UI16 is not equal (%v != %v)", genCodeSubject.UI16, co.UI16) - } - if genCodeSubject.UI32 != co.UI32 { - t.Errorf("Load (complex struct): struct.UI32 is not equal (%v != %v)", genCodeSubject.UI32, co.UI32) - } - if genCodeSubject.UI64 != co.UI64 { - t.Errorf("Load (complex struct): struct.UI64 is not equal (%v != %v)", genCodeSubject.UI64, co.UI64) - } - if genCodeSubject.S != co.S { - t.Errorf("Load (complex struct): struct.S is not equal (%v != %v)", genCodeSubject.S, co.S) - } - if !reflect.DeepEqual(genCodeSubject.Sp, co.Sp) { - t.Errorf("Load (complex struct): struct.Sp is not equal (%v != %v)", genCodeSubject.Sp, co.Sp) - } - if !reflect.DeepEqual(genCodeSubject.Sa, co.Sa) { - t.Errorf("Load (complex struct): struct.Sa is not equal (%v != %v)", genCodeSubject.Sa, co.Sa) - } - if !reflect.DeepEqual(genCodeSubject.Sap, co.Sap) { - t.Errorf("Load (complex struct): struct.Sap is not equal (%v != %v)", genCodeSubject.Sap, co.Sap) - } - if genCodeSubject.B != co.B { - t.Errorf("Load (complex struct): struct.B is not equal (%v != %v)", genCodeSubject.B, co.B) - } - if !reflect.DeepEqual(genCodeSubject.Bp, co.Bp) { - t.Errorf("Load (complex struct): struct.Bp is not equal (%v != %v)", genCodeSubject.Bp, co.Bp) - } - if !reflect.DeepEqual(genCodeSubject.Ba, co.Ba) { - t.Errorf("Load (complex struct): struct.Ba is not equal (%v != %v)", genCodeSubject.Ba, co.Ba) - } - if !reflect.DeepEqual(genCodeSubject.Bap, co.Bap) { - t.Errorf("Load (complex struct): struct.Bap is not equal (%v != %v)", genCodeSubject.Bap, co.Bap) - } - } - - } -} diff --git a/base/formats/dsd/format.go b/base/formats/dsd/format.go deleted file mode 100644 index c97950464..000000000 --- a/base/formats/dsd/format.go +++ /dev/null @@ -1,73 +0,0 @@ -package dsd - -import "errors" - -// Errors. -var ( - ErrIncompatibleFormat = errors.New("dsd: format is incompatible with operation") - ErrIsRaw = errors.New("dsd: given data is in raw format") - ErrUnknownFormat = errors.New("dsd: format is unknown") -) - -// Format types. -const ( - AUTO = 0 - - // Serialization types. - RAW = 1 - CBOR = 67 // C - GenCode = 71 // G - JSON = 74 // J - MsgPack = 77 // M - YAML = 89 // Y - - // Compression types. - GZIP = 90 // Z - - // Special types. - LIST = 76 // L -) - -// Default Formats. -var ( - DefaultSerializationFormat uint8 = JSON - DefaultCompressionFormat uint8 = GZIP -) - -// ValidateSerializationFormat validates if the format is for serialization, -// and returns the validated format as well as the result of the validation. -// If called on the AUTO format, it returns the default serialization format. -func ValidateSerializationFormat(format uint8) (validatedFormat uint8, ok bool) { - switch format { - case AUTO: - return DefaultSerializationFormat, true - case RAW: - return format, true - case CBOR: - return format, true - case GenCode: - return format, true - case JSON: - return format, true - case YAML: - return format, true - case MsgPack: - return format, true - default: - return 0, false - } -} - -// ValidateCompressionFormat validates if the format is for compression, -// and returns the validated format as well as the result of the validation. -// If called on the AUTO format, it returns the default compression format. 
-func ValidateCompressionFormat(format uint8) (validatedFormat uint8, ok bool) { - switch format { - case AUTO: - return DefaultCompressionFormat, true - case GZIP: - return format, true - default: - return 0, false - } -} diff --git a/base/formats/dsd/gencode_test.go b/base/formats/dsd/gencode_test.go deleted file mode 100644 index 2fbf18a00..000000000 --- a/base/formats/dsd/gencode_test.go +++ /dev/null @@ -1,824 +0,0 @@ -//nolint:nakedret,unconvert,gocognit,wastedassign,gofumpt -package dsd - -func (d *SimpleTestStruct) Size() (s uint64) { - - { - l := uint64(len(d.S)) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - s += l - } - s++ - return -} - -func (d *SimpleTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) { - size := d.Size() - { - if uint64(cap(buf)) >= size { - buf = buf[:size] - } else { - buf = make([]byte, size) - } - } - i := uint64(0) - - { - l := uint64(len(d.S)) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+0] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+0] = byte(t) - i++ - - } - copy(buf[i+0:], d.S) - i += l - } - { - buf[i+0] = d.B - } - return buf[:i+1], nil -} - -func (d *SimpleTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) { - i := uint64(0) - - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+0] & 0x7F) - for buf[i+0]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+0]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - d.S = string(buf[i+0 : i+0+l]) - i += l - } - { - d.B = buf[i+0] - } - return i + 1, nil -} - -func (d *GenCodeTestStruct) Size() (s uint64) { - - { - l := uint64(len(d.S)) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - s += l - } - { - if d.Sp != nil { - - { - l := uint64(len((*d.Sp))) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - s += l - } - s += 0 - } - } - { - l := uint64(len(d.Sa)) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - - for k0 := range d.Sa { - - { - l := uint64(len(d.Sa[k0])) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - s += l - } - - } - - } - { - if d.Sap != nil { - - { - l := uint64(len((*d.Sap))) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - - for k0 := range *d.Sap { - - { - l := uint64(len((*d.Sap)[k0])) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - s += l - } - - } - - } - s += 0 - } - } - { - if d.Bp != nil { - - s++ - } - } - { - l := uint64(len(d.Ba)) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - s += l - } - { - if d.Bap != nil { - - { - l := uint64(len((*d.Bap))) - - { - - t := l - for t >= 0x80 { - t >>= 7 - s++ - } - s++ - - } - s += l - } - s += 0 - } - } - s += 35 - return -} - -func (d *GenCodeTestStruct) GenCodeMarshal(buf []byte) ([]byte, error) { //nolint:maintidx - size := d.Size() - { - if uint64(cap(buf)) >= size { - buf = buf[:size] - } else { - buf = make([]byte, size) - } - } - i := uint64(0) - - { - - buf[0+0] = byte(d.I8 >> 0) - - } - { - - buf[0+1] = byte(d.I16 >> 0) - - buf[1+1] = byte(d.I16 >> 8) - - } - { - - buf[0+3] = byte(d.I32 >> 0) - - buf[1+3] = byte(d.I32 >> 8) - - buf[2+3] = byte(d.I32 >> 16) - - buf[3+3] = byte(d.I32 >> 24) - - } - { - - buf[0+7] = byte(d.I64 >> 0) - - buf[1+7] = byte(d.I64 >> 8) - - buf[2+7] = byte(d.I64 >> 16) - - buf[3+7] = byte(d.I64 >> 24) - - buf[4+7] = byte(d.I64 >> 32) - - buf[5+7] = byte(d.I64 >> 40) - - buf[6+7] = byte(d.I64 >> 48) - - buf[7+7] = byte(d.I64 >> 56) - - } - { - - buf[0+15] = byte(d.UI8 >> 0) - - } - { - - buf[0+16] = byte(d.UI16 >> 0) - - 
buf[1+16] = byte(d.UI16 >> 8) - - } - { - - buf[0+18] = byte(d.UI32 >> 0) - - buf[1+18] = byte(d.UI32 >> 8) - - buf[2+18] = byte(d.UI32 >> 16) - - buf[3+18] = byte(d.UI32 >> 24) - - } - { - - buf[0+22] = byte(d.UI64 >> 0) - - buf[1+22] = byte(d.UI64 >> 8) - - buf[2+22] = byte(d.UI64 >> 16) - - buf[3+22] = byte(d.UI64 >> 24) - - buf[4+22] = byte(d.UI64 >> 32) - - buf[5+22] = byte(d.UI64 >> 40) - - buf[6+22] = byte(d.UI64 >> 48) - - buf[7+22] = byte(d.UI64 >> 56) - - } - { - l := uint64(len(d.S)) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+30] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+30] = byte(t) - i++ - - } - copy(buf[i+30:], d.S) - i += l - } - { - if d.Sp == nil { - buf[i+30] = 0 - } else { - buf[i+30] = 1 - - { - l := uint64(len((*d.Sp))) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+31] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+31] = byte(t) - i++ - - } - copy(buf[i+31:], (*d.Sp)) - i += l - } - i += 0 - } - } - { - l := uint64(len(d.Sa)) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+31] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+31] = byte(t) - i++ - - } - for k0 := range d.Sa { - - { - l := uint64(len(d.Sa[k0])) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+31] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+31] = byte(t) - i++ - - } - copy(buf[i+31:], d.Sa[k0]) - i += l - } - - } - } - { - if d.Sap == nil { - buf[i+31] = 0 - } else { - buf[i+31] = 1 - - { - l := uint64(len((*d.Sap))) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+32] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+32] = byte(t) - i++ - - } - for k0 := range *d.Sap { - - { - l := uint64(len((*d.Sap)[k0])) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+32] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+32] = byte(t) - i++ - - } - copy(buf[i+32:], (*d.Sap)[k0]) - i += l - } - - } - } - i += 0 - } - } - { - buf[i+32] = d.B - } - { - if d.Bp == nil { - buf[i+33] = 0 - } else { - buf[i+33] = 1 - - { - buf[i+34] = (*d.Bp) - } - i++ - } - } - { - l := uint64(len(d.Ba)) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+34] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+34] = byte(t) - i++ - - } - copy(buf[i+34:], d.Ba) - i += l - } - { - if d.Bap == nil { - buf[i+34] = 0 - } else { - buf[i+34] = 1 - - { - l := uint64(len((*d.Bap))) - - { - - t := uint64(l) - - for t >= 0x80 { - buf[i+35] = byte(t) | 0x80 - t >>= 7 - i++ - } - buf[i+35] = byte(t) - i++ - - } - copy(buf[i+35:], (*d.Bap)) - i += l - } - i += 0 - } - } - return buf[:i+35], nil -} - -func (d *GenCodeTestStruct) GenCodeUnmarshal(buf []byte) (uint64, error) { //nolint:maintidx - i := uint64(0) - - { - - d.I8 = 0 | (int8(buf[i+0+0]) << 0) - - } - { - - d.I16 = 0 | (int16(buf[i+0+1]) << 0) | (int16(buf[i+1+1]) << 8) - - } - { - - d.I32 = 0 | (int32(buf[i+0+3]) << 0) | (int32(buf[i+1+3]) << 8) | (int32(buf[i+2+3]) << 16) | (int32(buf[i+3+3]) << 24) - - } - { - - d.I64 = 0 | (int64(buf[i+0+7]) << 0) | (int64(buf[i+1+7]) << 8) | (int64(buf[i+2+7]) << 16) | (int64(buf[i+3+7]) << 24) | (int64(buf[i+4+7]) << 32) | (int64(buf[i+5+7]) << 40) | (int64(buf[i+6+7]) << 48) | (int64(buf[i+7+7]) << 56) - - } - { - - d.UI8 = 0 | (uint8(buf[i+0+15]) << 0) - - } - { - - d.UI16 = 0 | (uint16(buf[i+0+16]) << 0) | (uint16(buf[i+1+16]) << 8) - - } - { - - d.UI32 = 0 | (uint32(buf[i+0+18]) << 0) | (uint32(buf[i+1+18]) << 8) | (uint32(buf[i+2+18]) << 16) | (uint32(buf[i+3+18]) << 24) - - } - { - - d.UI64 = 0 | (uint64(buf[i+0+22]) << 0) | (uint64(buf[i+1+22]) << 8) | (uint64(buf[i+2+22]) << 16) | (uint64(buf[i+3+22]) << 24) | (uint64(buf[i+4+22]) 
<< 32) | (uint64(buf[i+5+22]) << 40) | (uint64(buf[i+6+22]) << 48) | (uint64(buf[i+7+22]) << 56) - - } - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+30] & 0x7F) - for buf[i+30]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+30]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - d.S = string(buf[i+30 : i+30+l]) - i += l - } - { - if buf[i+30] == 1 { - if d.Sp == nil { - d.Sp = new(string) - } - - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+31] & 0x7F) - for buf[i+31]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+31]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - (*d.Sp) = string(buf[i+31 : i+31+l]) - i += l - } - i += 0 - } else { - d.Sp = nil - } - } - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+31] & 0x7F) - for buf[i+31]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+31]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - if uint64(cap(d.Sa)) >= l { - d.Sa = d.Sa[:l] - } else { - d.Sa = make([]string, l) - } - for k0 := range d.Sa { - - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+31] & 0x7F) - for buf[i+31]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+31]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - d.Sa[k0] = string(buf[i+31 : i+31+l]) - i += l - } - - } - } - { - if buf[i+31] == 1 { - if d.Sap == nil { - d.Sap = new([]string) - } - - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+32] & 0x7F) - for buf[i+32]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+32]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - if uint64(cap((*d.Sap))) >= l { - (*d.Sap) = (*d.Sap)[:l] - } else { - (*d.Sap) = make([]string, l) - } - for k0 := range *d.Sap { - - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+32] & 0x7F) - for buf[i+32]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+32]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - (*d.Sap)[k0] = string(buf[i+32 : i+32+l]) - i += l - } - - } - } - i += 0 - } else { - d.Sap = nil - } - } - { - d.B = buf[i+32] - } - { - if buf[i+33] == 1 { - if d.Bp == nil { - d.Bp = new(byte) - } - - { - (*d.Bp) = buf[i+34] - } - i++ - } else { - d.Bp = nil - } - } - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+34] & 0x7F) - for buf[i+34]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+34]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - if uint64(cap(d.Ba)) >= l { - d.Ba = d.Ba[:l] - } else { - d.Ba = make([]byte, l) - } - copy(d.Ba, buf[i+34:]) - i += l - } - { - if buf[i+34] == 1 { - if d.Bap == nil { - d.Bap = new([]byte) - } - - { - l := uint64(0) - - { - - bs := uint8(7) - t := uint64(buf[i+35] & 0x7F) - for buf[i+35]&0x80 == 0x80 { - i++ - t |= uint64(buf[i+35]&0x7F) << bs - bs += 7 - } - i++ - - l = t - - } - if uint64(cap((*d.Bap))) >= l { - (*d.Bap) = (*d.Bap)[:l] - } else { - (*d.Bap) = make([]byte, l) - } - copy((*d.Bap), buf[i+35:]) - i += l - } - i += 0 - } else { - d.Bap = nil - } - } - return i + 35, nil -} diff --git a/base/formats/dsd/http.go b/base/formats/dsd/http.go deleted file mode 100644 index 85aab163a..000000000 --- a/base/formats/dsd/http.go +++ /dev/null @@ -1,178 +0,0 @@ -package dsd - -import ( - "bytes" - "errors" - "fmt" - "io" - "net/http" - "strings" -) - -// HTTP Related Errors. -var ( - ErrMissingBody = errors.New("dsd: missing http body") - ErrMissingContentType = errors.New("dsd: missing http content type") -) - -const ( - httpHeaderContentType = "Content-Type" -) - -// LoadFromHTTPRequest loads the data from the body into the given interface. 
-func LoadFromHTTPRequest(r *http.Request, t interface{}) (format uint8, err error) { - return loadFromHTTP(r.Body, r.Header.Get(httpHeaderContentType), t) -} - -// LoadFromHTTPResponse loads the data from the body into the given interface. -// Closing the body is left to the caller. -func LoadFromHTTPResponse(resp *http.Response, t interface{}) (format uint8, err error) { - return loadFromHTTP(resp.Body, resp.Header.Get(httpHeaderContentType), t) -} - -func loadFromHTTP(body io.Reader, mimeType string, t interface{}) (format uint8, err error) { - // Read full body. - data, err := io.ReadAll(body) - if err != nil { - return 0, fmt.Errorf("dsd: failed to read http body: %w", err) - } - - // Load depending on mime type. - return MimeLoad(data, mimeType, t) -} - -// RequestHTTPResponseFormat sets the Accept header to the given format. -func RequestHTTPResponseFormat(r *http.Request, format uint8) (mimeType string, err error) { - // Get mime type. - mimeType, ok := FormatToMimeType[format] - if !ok { - return "", ErrIncompatibleFormat - } - - // Request response format. - r.Header.Set("Accept", mimeType) - - return mimeType, nil -} - -// DumpToHTTPRequest dumps the given data to the HTTP request using the given -// format. It also sets the Accept header to the same format. -func DumpToHTTPRequest(r *http.Request, t interface{}, format uint8) error { - // Get mime type and set request format. - mimeType, err := RequestHTTPResponseFormat(r, format) - if err != nil { - return err - } - - // Serialize data. - data, err := dumpWithoutIdentifier(t, format, "") - if err != nil { - return fmt.Errorf("dsd: failed to serialize: %w", err) - } - - // Add data to request. - r.Header.Set("Content-Type", mimeType) - r.Body = io.NopCloser(bytes.NewReader(data)) - - return nil -} - -// DumpToHTTPResponse dumpts the given data to the HTTP response, using the -// format defined in the request's Accept header. -func DumpToHTTPResponse(w http.ResponseWriter, r *http.Request, t interface{}) error { - // Serialize data based on accept header. - data, mimeType, _, err := MimeDump(t, r.Header.Get("Accept")) - if err != nil { - return fmt.Errorf("dsd: failed to serialize: %w", err) - } - - // Write data to response - w.Header().Set("Content-Type", mimeType) - _, err = w.Write(data) - if err != nil { - return fmt.Errorf("dsd: failed to write response: %w", err) - } - return nil -} - -// MimeLoad loads the given data into the interface based on the given mime type accept header. -func MimeLoad(data []byte, accept string, t interface{}) (format uint8, err error) { - // Find format. - format = FormatFromAccept(accept) - if format == 0 { - return 0, ErrIncompatibleFormat - } - - // Load data. - err = LoadAsFormat(data, format, t) - return format, err -} - -// MimeDump dumps the given interface based on the given mime type accept header. -func MimeDump(t any, accept string) (data []byte, mimeType string, format uint8, err error) { - // Find format. - format = FormatFromAccept(accept) - if format == AUTO { - return nil, "", 0, ErrIncompatibleFormat - } - - // Serialize and return. - data, err = dumpWithoutIdentifier(t, format, "") - return data, mimeType, format, err -} - -// FormatFromAccept returns the format for the given accept definition. -// The accept parameter matches the format of the HTTP Accept header. -// Special cases, in this order: -// - If accept is an empty string: returns default serialization format. -// - If accept contains no supported format, but a wildcard: returns default serialization format. 
-// - If accept contains no supported format, and no wildcard: returns AUTO format. -func FormatFromAccept(accept string) (format uint8) { - if accept == "" { - return DefaultSerializationFormat - } - - var foundWildcard bool - for _, mimeType := range strings.Split(accept, ",") { - // Clean mime type. - mimeType = strings.TrimSpace(mimeType) - mimeType, _, _ = strings.Cut(mimeType, ";") - if strings.Contains(mimeType, "/") { - _, mimeType, _ = strings.Cut(mimeType, "/") - } - mimeType = strings.ToLower(mimeType) - - // Check if mime type is supported. - format, ok := MimeTypeToFormat[mimeType] - if ok { - return format - } - - // Return default mime type as fallback if any mimetype is okay. - if mimeType == "*" { - foundWildcard = true - } - } - - if foundWildcard { - return DefaultSerializationFormat - } - return AUTO -} - -// Format and MimeType mappings. -var ( - FormatToMimeType = map[uint8]string{ - CBOR: "application/cbor", - JSON: "application/json", - MsgPack: "application/msgpack", - YAML: "application/yaml", - } - MimeTypeToFormat = map[string]uint8{ - "cbor": CBOR, - "json": JSON, - "msgpack": MsgPack, - "yaml": YAML, - "yml": YAML, - } -) diff --git a/base/formats/dsd/http_test.go b/base/formats/dsd/http_test.go deleted file mode 100644 index 32651ac84..000000000 --- a/base/formats/dsd/http_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package dsd - -import ( - "mime" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestMimeTypes(t *testing.T) { - t.Parallel() - - // Test static maps. - for _, mimeType := range FormatToMimeType { - cleaned, _, err := mime.ParseMediaType(mimeType) - assert.NoError(t, err, "mime type must be parse-able") - assert.Equal(t, mimeType, cleaned, "mime type should be clean in map already") - } - for mimeType := range MimeTypeToFormat { - cleaned, _, err := mime.ParseMediaType(mimeType) - assert.NoError(t, err, "mime type must be parse-able") - assert.Equal(t, mimeType, cleaned, "mime type should be clean in map already") - } - - // Test assumptions. - for accept, format := range map[string]uint8{ - "application/json, image/webp": JSON, - "image/webp, application/json": JSON, - "application/json;q=0.9, image/webp": JSON, - "*": DefaultSerializationFormat, - "*/*": DefaultSerializationFormat, - "text/yAMl": YAML, - " * , yaml ": YAML, - "yaml;charset ,*": YAML, - "xml,*": DefaultSerializationFormat, - "text/xml, text/other": AUTO, - "text/*": DefaultSerializationFormat, - "yaml ;charset": AUTO, // Invalid mimetype format. - "": DefaultSerializationFormat, - "x": AUTO, - } { - derivedFormat := FormatFromAccept(accept) - assert.Equal(t, format, derivedFormat, "assumption for %q should hold", accept) - } -} diff --git a/base/formats/dsd/interfaces.go b/base/formats/dsd/interfaces.go deleted file mode 100644 index cae605241..000000000 --- a/base/formats/dsd/interfaces.go +++ /dev/null @@ -1,9 +0,0 @@ -package dsd - -// GenCodeCompatible is an interface to identify and use gencode compatible structs. -type GenCodeCompatible interface { - // GenCodeMarshal gencode marshalls the struct into the given byte array, or a new one if its too small. - GenCodeMarshal(buf []byte) ([]byte, error) - // GenCodeUnmarshal gencode unmarshalls the struct and returns the bytes read. 
- GenCodeUnmarshal(buf []byte) (uint64, error) -} diff --git a/base/formats/dsd/tests.gencode b/base/formats/dsd/tests.gencode deleted file mode 100644 index bc29f5d36..000000000 --- a/base/formats/dsd/tests.gencode +++ /dev/null @@ -1,23 +0,0 @@ -struct SimpleTestStruct { - S string - B byte -} - -struct GenCodeTestStructure { - I8 int8 - I16 int16 - I32 int32 - I64 int64 - UI8 uint8 - UI16 uint16 - UI32 uint32 - UI64 uint64 - S string - Sp *string - Sa []string - Sap *[]string - B byte - Bp *byte - Ba []byte - Bap *[]byte -} diff --git a/base/formats/varint/helpers.go b/base/formats/varint/helpers.go deleted file mode 100644 index 0aa2c8154..000000000 --- a/base/formats/varint/helpers.go +++ /dev/null @@ -1,48 +0,0 @@ -package varint - -import "errors" - -// PrependLength prepends the varint encoded length of the byte slice to itself. -func PrependLength(data []byte) []byte { - return append(Pack64(uint64(len(data))), data...) -} - -// GetNextBlock extract the integer from the beginning of the given byte slice and returns the remaining bytes, the extracted integer, and whether there was an error. -func GetNextBlock(data []byte) ([]byte, int, error) { - l, n, err := Unpack64(data) - if err != nil { - return nil, 0, err - } - length := int(l) - totalLength := length + n - if totalLength > len(data) { - return nil, 0, errors.New("varint: not enough data for given block length") - } - return data[n:totalLength], totalLength, nil -} - -// EncodedSize returns the size required to varint-encode an uint. -func EncodedSize(n uint64) (size int) { - switch { - case n < 1<<7: // < 128 - return 1 - case n < 1<<14: // < 16384 - return 2 - case n < 1<<21: // < 2097152 - return 3 - case n < 1<<28: // < 268435456 - return 4 - case n < 1<<35: // < 34359738368 - return 5 - case n < 1<<42: // < 4398046511104 - return 6 - case n < 1<<49: // < 562949953421312 - return 7 - case n < 1<<56: // < 72057594037927936 - return 8 - case n < 1<<63: // < 9223372036854775808 - return 9 - default: - return 10 - } -} diff --git a/base/formats/varint/varint.go b/base/formats/varint/varint.go deleted file mode 100644 index 05880e09d..000000000 --- a/base/formats/varint/varint.go +++ /dev/null @@ -1,97 +0,0 @@ -package varint - -import ( - "encoding/binary" - "errors" -) - -// ErrBufTooSmall is returned when there is not enough data for parsing a varint. -var ErrBufTooSmall = errors.New("varint: buf too small") - -// Pack8 packs a uint8 into a VarInt. -func Pack8(n uint8) []byte { - if n < 128 { - return []byte{n} - } - return []byte{n, 0x01} -} - -// Pack16 packs a uint16 into a VarInt. -func Pack16(n uint16) []byte { - buf := make([]byte, 3) - w := binary.PutUvarint(buf, uint64(n)) - return buf[:w] -} - -// Pack32 packs a uint32 into a VarInt. -func Pack32(n uint32) []byte { - buf := make([]byte, 5) - w := binary.PutUvarint(buf, uint64(n)) - return buf[:w] -} - -// Pack64 packs a uint64 into a VarInt. -func Pack64(n uint64) []byte { - buf := make([]byte, 10) - w := binary.PutUvarint(buf, n) - return buf[:w] -} - -// Unpack8 unpacks a VarInt into a uint8. It returns the extracted int, how many bytes were used and an error. -func Unpack8(blob []byte) (uint8, int, error) { - if len(blob) < 1 { - return 0, 0, ErrBufTooSmall - } - if blob[0] < 128 { - return blob[0], 1, nil - } - if len(blob) < 2 { - return 0, 0, ErrBufTooSmall - } - if blob[1] != 0x01 { - return 0, 0, errors.New("varint: encoded integer greater than 255 (uint8)") - } - return blob[0], 1, nil -} - -// Unpack16 unpacks a VarInt into a uint16. 
It returns the extracted int, how many bytes were used and an error. -func Unpack16(blob []byte) (uint16, int, error) { - n, r := binary.Uvarint(blob) - if r == 0 { - return 0, 0, ErrBufTooSmall - } - if r < 0 { - return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)") - } - if n > 65535 { - return 0, 0, errors.New("varint: encoded integer greater than 65535 (uint16)") - } - return uint16(n), r, nil -} - -// Unpack32 unpacks a VarInt into a uint32. It returns the extracted int, how many bytes were used and an error. -func Unpack32(blob []byte) (uint32, int, error) { - n, r := binary.Uvarint(blob) - if r == 0 { - return 0, 0, ErrBufTooSmall - } - if r < 0 { - return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)") - } - if n > 4294967295 { - return 0, 0, errors.New("varint: encoded integer greater than 4294967295 (uint32)") - } - return uint32(n), r, nil -} - -// Unpack64 unpacks a VarInt into a uint64. It returns the extracted int, how many bytes were used and an error. -func Unpack64(blob []byte) (uint64, int, error) { - n, r := binary.Uvarint(blob) - if r == 0 { - return 0, 0, ErrBufTooSmall - } - if r < 0 { - return 0, 0, errors.New("varint: encoded integer greater than 18446744073709551615 (uint64)") - } - return n, r, nil -} diff --git a/base/formats/varint/varint_test.go b/base/formats/varint/varint_test.go deleted file mode 100644 index 9f2250ef7..000000000 --- a/base/formats/varint/varint_test.go +++ /dev/null @@ -1,141 +0,0 @@ -//nolint:gocognit -package varint - -import ( - "bytes" - "testing" -) - -func TestConversion(t *testing.T) { - t.Parallel() - - subjects := []struct { - intType uint8 - bytes []byte - integer uint64 - }{ - {8, []byte{0x00}, 0}, - {8, []byte{0x01}, 1}, - {8, []byte{0x7F}, 127}, - {8, []byte{0x80, 0x01}, 128}, - {8, []byte{0xFF, 0x01}, 255}, - - {16, []byte{0x80, 0x02}, 256}, - {16, []byte{0xFF, 0x7F}, 16383}, - {16, []byte{0x80, 0x80, 0x01}, 16384}, - {16, []byte{0xFF, 0xFF, 0x03}, 65535}, - - {32, []byte{0x80, 0x80, 0x04}, 65536}, - {32, []byte{0xFF, 0xFF, 0x7F}, 2097151}, - {32, []byte{0x80, 0x80, 0x80, 0x01}, 2097152}, - {32, []byte{0xFF, 0xFF, 0xFF, 0x07}, 16777215}, - {32, []byte{0x80, 0x80, 0x80, 0x08}, 16777216}, - {32, []byte{0xFF, 0xFF, 0xFF, 0x7F}, 268435455}, - {32, []byte{0x80, 0x80, 0x80, 0x80, 0x01}, 268435456}, - {32, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0x0F}, 4294967295}, - - {64, []byte{0x80, 0x80, 0x80, 0x80, 0x10}, 4294967296}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 34359738367}, - {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 34359738368}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x1F}, 1099511627775}, - {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x20}, 1099511627776}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 4398046511103}, - {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 4398046511104}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3F}, 281474976710655}, - {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x40}, 281474976710656}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 562949953421311}, - {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 562949953421312}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 72057594037927935}, - - {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 72057594037927936}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}, 9223372036854775807}, - - {64, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01}, 
9223372036854775808}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, 18446744073709551615}, - } - - for _, subject := range subjects { - - actualInteger, _, err := Unpack64(subject.bytes) - if err != nil || actualInteger != subject.integer { - t.Errorf("Unpack64 %d: expected %d, actual %d", subject.bytes, subject.integer, actualInteger) - } - actualBytes := Pack64(subject.integer) - if err != nil || !bytes.Equal(actualBytes, subject.bytes) { - t.Errorf("Pack64 %d: expected %d, actual %d", subject.integer, subject.bytes, actualBytes) - } - - if subject.intType <= 32 { - actualInteger, _, err := Unpack32(subject.bytes) - if err != nil || actualInteger != uint32(subject.integer) { - t.Errorf("Unpack32 %d: expected %d, actual %d", subject.bytes, subject.integer, actualInteger) - } - actualBytes := Pack32(uint32(subject.integer)) - if err != nil || !bytes.Equal(actualBytes, subject.bytes) { - t.Errorf("Pack32 %d: expected %d, actual %d", subject.integer, subject.bytes, actualBytes) - } - } - - if subject.intType <= 16 { - actualInteger, _, err := Unpack16(subject.bytes) - if err != nil || actualInteger != uint16(subject.integer) { - t.Errorf("Unpack16 %d: expected %d, actual %d", subject.bytes, subject.integer, actualInteger) - } - actualBytes := Pack16(uint16(subject.integer)) - if err != nil || !bytes.Equal(actualBytes, subject.bytes) { - t.Errorf("Pack16 %d: expected %d, actual %d", subject.integer, subject.bytes, actualBytes) - } - } - - if subject.intType <= 8 { - actualInteger, _, err := Unpack8(subject.bytes) - if err != nil || actualInteger != uint8(subject.integer) { - t.Errorf("Unpack8 %d: expected %d, actual %d", subject.bytes, subject.integer, actualInteger) - } - actualBytes := Pack8(uint8(subject.integer)) - if err != nil || !bytes.Equal(actualBytes, subject.bytes) { - t.Errorf("Pack8 %d: expected %d, actual %d", subject.integer, subject.bytes, actualBytes) - } - } - - } -} - -func TestFails(t *testing.T) { - t.Parallel() - - subjects := []struct { - intType uint8 - bytes []byte - }{ - {32, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}}, - {64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02}}, - {64, []byte{0xFF}}, - } - - for _, subject := range subjects { - - if subject.intType == 64 { - _, _, err := Unpack64(subject.bytes) - if err == nil { - t.Errorf("Unpack64 %d: expected error while unpacking.", subject.bytes) - } - } - - _, _, err := Unpack32(subject.bytes) - if err == nil { - t.Errorf("Unpack32 %d: expected error while unpacking.", subject.bytes) - } - - _, _, err = Unpack16(subject.bytes) - if err == nil { - t.Errorf("Unpack16 %d: expected error while unpacking.", subject.bytes) - } - - _, _, err = Unpack8(subject.bytes) - if err == nil { - t.Errorf("Unpack8 %d: expected error while unpacking.", subject.bytes) - } - - } -} diff --git a/cmds/notifier/notify.go b/cmds/notifier/notify.go index d78c36109..48e117c07 100644 --- a/cmds/notifier/notify.go +++ b/cmds/notifier/notify.go @@ -7,9 +7,9 @@ import ( "time" "github.com/safing/portmaster/base/api/client" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" pbnotify "github.com/safing/portmaster/base/notifications" + "github.com/safing/structures/dsd" ) const ( diff --git a/cmds/notifier/spn.go b/cmds/notifier/spn.go index 30fa18f81..0b49d63c9 100644 --- a/cmds/notifier/spn.go +++ b/cmds/notifier/spn.go @@ -7,8 +7,8 @@ import ( "github.com/tevino/abool" "github.com/safing/portmaster/base/api/client" - 
"github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" + "github.com/safing/structures/dsd" ) const ( diff --git a/cmds/notifier/subsystems.go b/cmds/notifier/subsystems.go index 38390cfd2..587d6d84e 100644 --- a/cmds/notifier/subsystems.go +++ b/cmds/notifier/subsystems.go @@ -4,8 +4,8 @@ import ( "sync" "github.com/safing/portmaster/base/api/client" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" + "github.com/safing/structures/dsd" ) const ( diff --git a/cmds/portmaster-start/logs.go b/cmds/portmaster-start/logs.go index 280ef8cc8..f2f514b97 100644 --- a/cmds/portmaster-start/logs.go +++ b/cmds/portmaster-start/logs.go @@ -11,9 +11,9 @@ import ( "github.com/spf13/cobra" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/info" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) func initializeLogFile(logFilePath string, identifier string, version string) *os.File { diff --git a/service/intel/filterlists/decoder.go b/service/intel/filterlists/decoder.go index 4083a237f..a0b580f77 100644 --- a/service/intel/filterlists/decoder.go +++ b/service/intel/filterlists/decoder.go @@ -8,8 +8,8 @@ import ( "fmt" "io" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/utils" + "github.com/safing/structures/dsd" ) type listEntry struct { diff --git a/service/intel/filterlists/index.go b/service/intel/filterlists/index.go index 74770fe27..4b59adde4 100644 --- a/service/intel/filterlists/index.go +++ b/service/intel/filterlists/index.go @@ -9,10 +9,10 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/updater" "github.com/safing/portmaster/service/updates" + "github.com/safing/structures/dsd" ) // the following definitions are copied from the intelhub repository diff --git a/service/netquery/manager.go b/service/netquery/manager.go index 31f2f7f8d..787da764c 100644 --- a/service/netquery/manager.go +++ b/service/netquery/manager.go @@ -7,10 +7,10 @@ import ( "time" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/network" + "github.com/safing/structures/dsd" ) type ( diff --git a/service/netquery/runtime_query_runner.go b/service/netquery/runtime_query_runner.go index 0d09b51d7..bb3ffbad8 100644 --- a/service/netquery/runtime_query_runner.go +++ b/service/netquery/runtime_query_runner.go @@ -7,10 +7,10 @@ import ( "strings" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/runtime" "github.com/safing/portmaster/service/netquery/orm" + "github.com/safing/structures/dsd" ) // RuntimeQueryRunner provides a simple interface for the runtime database diff --git a/service/profile/api.go b/service/profile/api.go index e23a4c851..ba8b24941 100644 --- a/service/profile/api.go +++ b/service/profile/api.go @@ -8,9 +8,9 @@ import ( "strings" "github.com/safing/portmaster/base/api" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/utils" 
"github.com/safing/portmaster/service/profile/binmeta" + "github.com/safing/structures/dsd" ) func registerAPIEndpoints() error { diff --git a/service/sync/setting_single.go b/service/sync/setting_single.go index c738c1021..9566fe0f8 100644 --- a/service/sync/setting_single.go +++ b/service/sync/setting_single.go @@ -9,8 +9,8 @@ import ( "github.com/safing/portmaster/base/api" "github.com/safing/portmaster/base/config" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/service/profile" + "github.com/safing/structures/dsd" ) // SingleSettingExport holds an export of a single setting. diff --git a/service/sync/util.go b/service/sync/util.go index 7fd3fb0c2..9d2cc9443 100644 --- a/service/sync/util.go +++ b/service/sync/util.go @@ -10,8 +10,8 @@ import ( "github.com/safing/jess/filesig" "github.com/safing/portmaster/base/api" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) // Type is the type of an export. diff --git a/spn/access/client.go b/spn/access/client.go index fddae23df..ca5c8f200 100644 --- a/spn/access/client.go +++ b/spn/access/client.go @@ -9,10 +9,10 @@ import ( "time" "github.com/safing/portmaster/base/database" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/access/account" "github.com/safing/portmaster/spn/access/token" + "github.com/safing/structures/dsd" ) // Client URLs. diff --git a/spn/access/storage.go b/spn/access/storage.go index 617d3c66a..04f1a5fe1 100644 --- a/spn/access/storage.go +++ b/spn/access/storage.go @@ -9,9 +9,9 @@ import ( "github.com/safing/portmaster/base/database" "github.com/safing/portmaster/base/database/query" "github.com/safing/portmaster/base/database/record" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/access/token" + "github.com/safing/structures/dsd" ) func loadTokens() { diff --git a/spn/access/token/pblind.go b/spn/access/token/pblind.go index 97aa18922..1342a2831 100644 --- a/spn/access/token/pblind.go +++ b/spn/access/token/pblind.go @@ -13,8 +13,8 @@ import ( "github.com/mr-tron/base58" "github.com/rot256/pblind" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) const pblindSecretSize = 32 diff --git a/spn/access/token/request_test.go b/spn/access/token/request_test.go index e5525a412..ffc22f095 100644 --- a/spn/access/token/request_test.go +++ b/spn/access/token/request_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/dsd" ) func TestFull(t *testing.T) { diff --git a/spn/access/token/scramble.go b/spn/access/token/scramble.go index c6ef236f3..083f2fd1d 100644 --- a/spn/access/token/scramble.go +++ b/spn/access/token/scramble.go @@ -7,7 +7,7 @@ import ( "github.com/mr-tron/base58" "github.com/safing/jess/lhash" - "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/dsd" ) const ( diff --git a/spn/cabin/verification.go b/spn/cabin/verification.go index 4b1c0e48e..f60af514b 100644 --- a/spn/cabin/verification.go +++ b/spn/cabin/verification.go @@ -6,9 +6,9 @@ import ( "fmt" "github.com/safing/jess" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/hub" + "github.com/safing/structures/dsd" ) var ( diff --git 
a/spn/captain/bootstrap.go b/spn/captain/bootstrap.go index 4ccb3370e..d6e8e68ec 100644 --- a/spn/captain/bootstrap.go +++ b/spn/captain/bootstrap.go @@ -7,11 +7,11 @@ import ( "io/fs" "os" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/navigator" + "github.com/safing/structures/dsd" ) // BootstrapFile is used for sideloading bootstrap data. diff --git a/spn/captain/op_gossip.go b/spn/captain/op_gossip.go index de5edaa4d..b80fefa72 100644 --- a/spn/captain/op_gossip.go +++ b/spn/captain/op_gossip.go @@ -3,13 +3,13 @@ package captain import ( "time" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/docks" "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/varint" ) // GossipOpType is the type ID of the gossip operation. diff --git a/spn/captain/op_gossip_query.go b/spn/captain/op_gossip_query.go index 27d605a8d..d6e6ad3ea 100644 --- a/spn/captain/op_gossip_query.go +++ b/spn/captain/op_gossip_query.go @@ -5,7 +5,6 @@ import ( "strings" "time" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" @@ -13,6 +12,7 @@ import ( "github.com/safing/portmaster/spn/hub" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/varint" ) // GossipQueryOpType is the type ID of the gossip query operation. diff --git a/spn/crew/op_connect.go b/spn/crew/op_connect.go index 394c4fc57..119126ede 100644 --- a/spn/crew/op_connect.go +++ b/spn/crew/op_connect.go @@ -10,7 +10,6 @@ import ( "sync/atomic" "time" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network/netutils" @@ -18,6 +17,7 @@ import ( "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) // ConnectOpType is the type ID for the connection operation. diff --git a/spn/crew/op_ping.go b/spn/crew/op_ping.go index eb8b240f4..43d433648 100644 --- a/spn/crew/op_ping.go +++ b/spn/crew/op_ping.go @@ -4,10 +4,10 @@ import ( "crypto/subtle" "time" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) const ( diff --git a/spn/docks/bandwidth_test.go b/spn/docks/bandwidth_test.go index 3599ce9c6..f41728637 100644 --- a/spn/docks/bandwidth_test.go +++ b/spn/docks/bandwidth_test.go @@ -6,9 +6,9 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) func TestEffectiveBandwidth(t *testing.T) { //nolint:paralleltest // Run alone. 
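
The dsd HTTP helpers deleted above keep working under the new import path used throughout these hunks. A minimal server-side sketch, assuming the API is unchanged after the move; the Status type and listen address are illustrative:

package main

import (
	"log"
	"net/http"

	"github.com/safing/structures/dsd"
)

// Status is an illustrative payload type.
type Status struct {
	Healthy bool
	Message string
}

func main() {
	http.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {
		// DumpToHTTPResponse serializes the payload in the format requested
		// via the Accept header (JSON if the header is empty or contains
		// only wildcards) and writes the result to the response.
		err := dsd.DumpToHTTPResponse(w, r, &Status{Healthy: true, Message: "ok"})
		if err != nil {
			http.Error(w, err.Error(), http.StatusNotAcceptable)
		}
	})
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", nil))
}
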
diff --git a/spn/docks/crane.go b/spn/docks/crane.go index 4d2a84bda..9c6d11830 100644 --- a/spn/docks/crane.go +++ b/spn/docks/crane.go @@ -12,7 +12,6 @@ import ( "github.com/tevino/abool" "github.com/safing/jess" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/mgr" @@ -21,6 +20,7 @@ import ( "github.com/safing/portmaster/spn/ships" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/varint" ) const ( diff --git a/spn/docks/crane_init.go b/spn/docks/crane_init.go index 9bf08773d..3648913c8 100644 --- a/spn/docks/crane_init.go +++ b/spn/docks/crane_init.go @@ -5,13 +5,13 @@ import ( "time" "github.com/safing/jess" - "github.com/safing/portmaster/base/formats/dsd" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/info" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" + "github.com/safing/structures/varint" ) /* diff --git a/spn/docks/crane_verify.go b/spn/docks/crane_verify.go index f6f976a7a..679899244 100644 --- a/spn/docks/crane_verify.go +++ b/spn/docks/crane_verify.go @@ -6,10 +6,10 @@ import ( "fmt" "time" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/varint" ) const ( diff --git a/spn/docks/op_capacity.go b/spn/docks/op_capacity.go index e26aec352..86f38eb36 100644 --- a/spn/docks/op_capacity.go +++ b/spn/docks/op_capacity.go @@ -7,11 +7,11 @@ import ( "github.com/tevino/abool" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) const ( diff --git a/spn/docks/op_latency.go b/spn/docks/op_latency.go index 59681fc12..6b68ad814 100644 --- a/spn/docks/op_latency.go +++ b/spn/docks/op_latency.go @@ -5,12 +5,12 @@ import ( "fmt" "time" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/varint" ) const ( diff --git a/spn/docks/op_sync_state.go b/spn/docks/op_sync_state.go index c5303544e..ff1d8e730 100644 --- a/spn/docks/op_sync_state.go +++ b/spn/docks/op_sync_state.go @@ -4,11 +4,11 @@ import ( "context" "time" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/spn/conf" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) // SyncStateOpType is the type ID of the sync state operation. 
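
For types on a hot path, the GenCode format shown earlier is meant to bypass reflection-based encoders via the GenCodeCompatible interface. A hand-rolled sketch for a hypothetical two-byte struct; real implementations are generated from a .gencode file, this one only illustrates the interface contract:

package main

import (
	"errors"
	"fmt"

	"github.com/safing/structures/dsd"
)

// Pair is a hypothetical struct, used only to illustrate the interface.
type Pair struct {
	A byte
	B byte
}

// GenCodeMarshal writes the struct into buf, allocating a new slice if
// the given one is too small.
func (p *Pair) GenCodeMarshal(buf []byte) ([]byte, error) {
	if cap(buf) >= 2 {
		buf = buf[:2]
	} else {
		buf = make([]byte, 2)
	}
	buf[0], buf[1] = p.A, p.B
	return buf, nil
}

// GenCodeUnmarshal reads the struct from buf and returns how many bytes
// were consumed.
func (p *Pair) GenCodeUnmarshal(buf []byte) (uint64, error) {
	if len(buf) < 2 {
		return 0, errors.New("buffer too small")
	}
	p.A, p.B = buf[0], buf[1]
	return 2, nil
}

func main() {
	// Dump dispatches to GenCodeMarshal when the GenCode format is used.
	data, err := dsd.Dump(&Pair{A: 1, B: 2}, dsd.GenCode)
	if err != nil {
		panic(err)
	}

	out := &Pair{}
	if _, err := dsd.Load(data, out); err != nil {
		panic(err)
	}
	fmt.Println(out.A, out.B) // 1 2
}
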
diff --git a/spn/docks/op_whoami.go b/spn/docks/op_whoami.go index 53ca914d6..664f85a55 100644 --- a/spn/docks/op_whoami.go +++ b/spn/docks/op_whoami.go @@ -3,9 +3,9 @@ package docks import ( "time" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/spn/terminal" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) const ( diff --git a/spn/hub/update.go b/spn/hub/update.go index 0d6a9efd9..f77dbbfd4 100644 --- a/spn/hub/update.go +++ b/spn/hub/update.go @@ -8,10 +8,10 @@ import ( "github.com/safing/jess" "github.com/safing/jess/lhash" "github.com/safing/portmaster/base/database" - "github.com/safing/portmaster/base/formats/dsd" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/network/netutils" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" ) var ( diff --git a/spn/hub/update_test.go b/spn/hub/update_test.go index 0f8d667e3..69810390a 100644 --- a/spn/hub/update_test.go +++ b/spn/hub/update_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/safing/jess" - "github.com/safing/portmaster/base/formats/dsd" + "github.com/safing/structures/dsd" ) func TestHubUpdate(t *testing.T) { diff --git a/spn/terminal/control_flow.go b/spn/terminal/control_flow.go index af328fe88..8da522ee7 100644 --- a/spn/terminal/control_flow.go +++ b/spn/terminal/control_flow.go @@ -7,8 +7,8 @@ import ( "sync/atomic" "time" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/service/mgr" + "github.com/safing/structures/varint" ) // FlowControl defines the flow control interface. diff --git a/spn/terminal/errors.go b/spn/terminal/errors.go index bc762bb3b..2d209b487 100644 --- a/spn/terminal/errors.go +++ b/spn/terminal/errors.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - "github.com/safing/portmaster/base/formats/varint" + "github.com/safing/structures/varint" ) // Error is a terminal error. diff --git a/spn/terminal/init.go b/spn/terminal/init.go index a437e9f5f..ee8633e28 100644 --- a/spn/terminal/init.go +++ b/spn/terminal/init.go @@ -4,11 +4,11 @@ import ( "context" "github.com/safing/jess" - "github.com/safing/portmaster/base/formats/dsd" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/spn/cabin" "github.com/safing/portmaster/spn/hub" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" + "github.com/safing/structures/varint" ) /* diff --git a/spn/terminal/msgtypes.go b/spn/terminal/msgtypes.go index fba9d3235..c158e93bd 100644 --- a/spn/terminal/msgtypes.go +++ b/spn/terminal/msgtypes.go @@ -1,8 +1,8 @@ package terminal import ( - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/structures/container" + "github.com/safing/structures/varint" ) /* diff --git a/spn/terminal/operation_counter.go b/spn/terminal/operation_counter.go index 687ade535..5ac2dacd3 100644 --- a/spn/terminal/operation_counter.go +++ b/spn/terminal/operation_counter.go @@ -5,11 +5,11 @@ import ( "sync" "time" - "github.com/safing/portmaster/base/formats/dsd" - "github.com/safing/portmaster/base/formats/varint" "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" "github.com/safing/structures/container" + "github.com/safing/structures/dsd" + "github.com/safing/structures/varint" ) // CounterOpType is the type ID for the Counter Operation. 
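
The varint helpers deleted above (now imported as github.com/safing/structures/varint throughout these hunks) are mostly used for length-prefixed framing in the SPN terminal and docks code. A minimal sketch of that framing pattern, with illustrative message contents:

package main

import (
	"fmt"

	"github.com/safing/structures/varint"
)

func main() {
	// Frame two messages by prepending their varint-encoded length.
	var stream []byte
	stream = append(stream, varint.PrependLength([]byte("hello"))...)
	stream = append(stream, varint.PrependLength([]byte("spn"))...)

	// Read the frames back: GetNextBlock returns the payload and the
	// total number of bytes consumed (length prefix plus payload).
	for len(stream) > 0 {
		msg, n, err := varint.GetNextBlock(stream)
		if err != nil {
			panic(err)
		}
		fmt.Println(string(msg)) // "hello", then "spn"
		stream = stream[n:]
	}
}
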
From a5cf25f3e75f02cfba2c699a85034993e43f8191 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 9 Aug 2024 13:25:06 +0200 Subject: [PATCH 53/56] Improve expansion test --- spn/docks/module_test.go | 3 +++ spn/docks/terminal_expansion_test.go | 12 ++++-------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/spn/docks/module_test.go b/spn/docks/module_test.go index 80acc96d2..e8cf6cd0c 100644 --- a/spn/docks/module_test.go +++ b/spn/docks/module_test.go @@ -7,6 +7,7 @@ import ( "github.com/safing/portmaster/base/config" "github.com/safing/portmaster/base/database/dbmodule" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/base/metrics" "github.com/safing/portmaster/base/rng" "github.com/safing/portmaster/service/core/base" @@ -46,6 +47,8 @@ func (stub *testInstance) Stopping() bool { func (stub *testInstance) SetCmdLineOperation(f func() error) {} func runTest(m *testing.M) error { + log.Start() + ds, err := config.InitializeUnitTestDataroot("test-docks") if err != nil { return fmt.Errorf("failed to initialize dataroot: %w", err) diff --git a/spn/docks/terminal_expansion_test.go b/spn/docks/terminal_expansion_test.go index 656f5fdcc..4d9029de4 100644 --- a/spn/docks/terminal_expansion_test.go +++ b/spn/docks/terminal_expansion_test.go @@ -76,7 +76,7 @@ func testExpansion( //nolint:maintidx,thelper serverCountTo uint64, inParallel bool, ) { - testID += fmt.Sprintf(":encrypt=%v,flowType=%d,parallel=%v", terminalOpts.Encrypt, terminalOpts.FlowControl, inParallel) + testID += fmt.Sprintf(":encrypt=%v,flowCtrl=%d,parallel=%v", terminalOpts.Encrypt, terminalOpts.FlowControl, inParallel) var identity2, identity3, identity4 *cabin.Identity var connectedHub2, connectedHub3, connectedHub4 *hub.Hub @@ -94,10 +94,11 @@ func testExpansion( //nolint:maintidx,thelper var crane1, crane2to1, crane2to3, crane3to2, crane3to4, crane4 *Crane var craneWg sync.WaitGroup - started := time.Now() - craneCtx, cancelCraneCtx := context.WithCancel(context.Background()) craneWg.Add(6) + craneCtx, cancelCraneCtx := context.WithCancel(context.Background()) + defer cancelCraneCtx() + go func() { var err error crane1, err = NewCrane(ship1to2, connectedHub2, nil) @@ -291,9 +292,6 @@ func testExpansion( //nolint:maintidx,thelper op1.Wait() } - // Wait for double the time, so that the counters can complete in both directions. - time.Sleep(time.Since(started)) - // Signal completion. close(finished) @@ -301,8 +299,6 @@ func testExpansion( //nolint:maintidx,thelper // if we succeeded. time.Sleep(100 * time.Millisecond) - cancelCraneCtx() - // Check errors. 
if op1.Error != nil { t.Fatalf("crane test %s counter op1 failed: %s", testID, op1.Error) From 06862b77db2d36869054714616d7a598e1867464 Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 9 Aug 2024 14:53:57 +0200 Subject: [PATCH 54/56] Fix linter warnings --- base/api/authentication.go | 9 ++---- base/api/main.go | 5 ++-- base/database/registry.go | 22 -------------- base/rng/rng.go | 9 ++++-- base/rng/test/main.go | 2 ++ base/runtime/provider.go | 2 +- base/updater/fetch.go | 1 - cmds/observation-hub/apprise.go | 6 +++- cmds/observation-hub/main.go | 7 +++++ cmds/observation-hub/observe.go | 4 +++ service/compat/module.go | 3 ++ service/compat/notify.go | 2 +- service/core/base/module.go | 4 +++ service/core/core.go | 9 +++--- service/firewall/interception/module.go | 4 +++ service/firewall/module.go | 1 - service/instance.go | 5 ++-- service/mgr/states.go | 2 +- service/mgr/workermgr_test.go | 24 ++++++++++------ service/netenv/init_test.go | 2 +- service/network/connection_handler.go | 3 +- service/resolver/main.go | 10 +++++-- service/resolver/main_test.go | 2 +- service/updates/main.go | 2 -- spn/captain/hooks.go | 2 +- spn/captain/intel.go | 2 +- spn/captain/module.go | 22 +++++++++----- spn/captain/navigation.go | 2 +- spn/captain/public.go | 8 +++--- spn/docks/hub_import.go | 2 +- spn/docks/module.go | 4 +++ spn/docks/module_test.go | 38 ++++++++++++------------- spn/instance.go | 3 +- spn/sluice/packet_listener.go | 3 +- spn/sluice/udp_listener.go | 3 +- spn/terminal/module.go | 6 +++- spn/terminal/module_test.go | 29 +++++++++---------- spn/unit/scheduler.go | 3 +- spn/unit/unit_test.go | 3 +- 39 files changed, 154 insertions(+), 116 deletions(-) diff --git a/base/api/authentication.go b/base/api/authentication.go index 73e659f53..2ea09f7aa 100644 --- a/base/api/authentication.go +++ b/base/api/authentication.go @@ -1,7 +1,6 @@ package api import ( - "context" "encoding/base64" "errors" "fmt" @@ -116,8 +115,8 @@ func (sess *session) Refresh(ttl time.Duration) { // permission for an API handler. The returned permission is the required // permission for the request to proceed. type AuthenticatedHandler interface { - ReadPermission(*http.Request) Permission - WritePermission(*http.Request) Permission + ReadPermission(r *http.Request) Permission + WritePermission(r *http.Request) Permission } // SetAuthenticator sets an authenticator function for the API endpoint. If none is set, all requests will be permitted. 
@@ -351,7 +350,7 @@ func checkAPIKey(r *http.Request) *AuthToken { return token } -func updateAPIKeys(_ context.Context) error { +func updateAPIKeys() { apiKeysLock.Lock() defer apiKeysLock.Unlock() @@ -443,8 +442,6 @@ func updateAPIKeys(_ context.Context) error { return nil }) } - - return nil } func checkSessionCookie(r *http.Request) *AuthToken { diff --git a/base/api/main.go b/base/api/main.go index cc2187788..d8e52f851 100644 --- a/base/api/main.go +++ b/base/api/main.go @@ -52,10 +52,11 @@ func prep() error { func start() error { startServer() - _ = updateAPIKeys(module.mgr.Ctx()) + updateAPIKeys() module.instance.Config().EventConfigChange.AddCallback("update API keys", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { - return false, updateAPIKeys(wc.Ctx()) + updateAPIKeys() + return false, nil }) // start api auth token cleaner diff --git a/base/database/registry.go b/base/database/registry.go index 6c8f59e3b..44c5d9703 100644 --- a/base/database/registry.go +++ b/base/database/registry.go @@ -6,18 +6,9 @@ import ( "regexp" "sync" "time" - - "github.com/tevino/abool" -) - -const ( - registryFileName = "databases.json" ) var ( - registryPersistence = abool.NewBool(false) - writeRegistrySoon = abool.NewBool(false) - registry = make(map[string]*Database) registryLock sync.Mutex @@ -33,17 +24,14 @@ func Register(db *Database) (*Database, error) { defer registryLock.Unlock() registeredDB, ok := registry[db.Name] - save := false if ok { // update database if registeredDB.Description != db.Description { registeredDB.Description = db.Description - save = true } if registeredDB.ShadowDelete != db.ShadowDelete { registeredDB.ShadowDelete = db.ShadowDelete - save = true } } else { // register new database @@ -57,13 +45,6 @@ func Register(db *Database) (*Database, error) { db.LastLoaded = time.Time{} registry[db.Name] = db - save = true - } - - if save && registryPersistence.IsSet() { - if ok { - registeredDB.Updated() - } } if ok { @@ -80,9 +61,6 @@ func getDatabase(name string) (*Database, error) { if !ok { return nil, fmt.Errorf(`database "%s" not registered`, name) } - if time.Now().Add(-24 * time.Hour).After(registeredDB.LastLoaded) { - writeRegistrySoon.Set() - } registeredDB.Loaded() return registeredDB, nil diff --git a/base/rng/rng.go b/base/rng/rng.go index 01ad5dd1a..8ff1a7fe4 100644 --- a/base/rng/rng.go +++ b/base/rng/rng.go @@ -10,10 +10,12 @@ import ( "sync/atomic" "github.com/aead/serpent" - "github.com/safing/portmaster/service/mgr" "github.com/seehuhn/fortuna" + + "github.com/safing/portmaster/service/mgr" ) +// Rng is a random number generator. type Rng struct { mgr *mgr.Manager @@ -27,7 +29,6 @@ var ( rngCipher = "aes" // Possible values: "aes", "serpent". - ) func newCipher(key []byte) (cipher.Block, error) { @@ -41,10 +42,12 @@ func newCipher(key []byte) (cipher.Block, error) { } } +// Manager returns the module manager. func (r *Rng) Manager() *mgr.Manager { return r.mgr } +// Start starts the module. func (r *Rng) Start() error { rngLock.Lock() defer rngLock.Unlock() @@ -84,6 +87,7 @@ func (r *Rng) Start() error { return nil } +// Stop stops the module. func (r *Rng) Stop() error { return nil } @@ -93,6 +97,7 @@ var ( shimLoaded atomic.Bool ) +// New returns a new rng. 
func New(instance instance) (*Rng, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") diff --git a/base/rng/test/main.go b/base/rng/test/main.go index 0778a1376..7096ac1ae 100644 --- a/base/rng/test/main.go +++ b/base/rng/test/main.go @@ -20,6 +20,7 @@ import ( "github.com/safing/portmaster/service/mgr" ) +// Test tests the rng. type Test struct { mgr *mgr.Manager @@ -207,6 +208,7 @@ func noise(ctx *mgr.WorkerCtx) error { } } +// New returns a new rng test. func New(instance instance) (*Test, error) { if !shimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") diff --git a/base/runtime/provider.go b/base/runtime/provider.go index 5951e89b6..7eb0dd78d 100644 --- a/base/runtime/provider.go +++ b/base/runtime/provider.go @@ -31,7 +31,7 @@ type ( // should be returned. It is guaranteed that the key of // the record passed to Set is prefixed with the key used // to register the value provider. - Set(record.Record) (record.Record, error) + Set(r record.Record) (record.Record, error) // Get should return one or more records that match keyOrPrefix. // keyOrPrefix is guaranteed to be at least the prefix used to // register the ValueProvider. diff --git a/base/updater/fetch.go b/base/updater/fetch.go index e3c397e47..f324709d4 100644 --- a/base/updater/fetch.go +++ b/base/updater/fetch.go @@ -49,7 +49,6 @@ func (reg *ResourceRegistry) fetchFile(ctx context.Context, client *http.Client, rv.versionedSigPath(), rv.SigningMetadata(), tries, ) - if err != nil { switch rv.resource.VerificationOptions.DownloadPolicy { case SignaturePolicyRequire: diff --git a/cmds/observation-hub/apprise.go b/cmds/observation-hub/apprise.go index c64d3dedb..c9bf0e65e 100644 --- a/cmds/observation-hub/apprise.go +++ b/cmds/observation-hub/apprise.go @@ -19,20 +19,24 @@ import ( "github.com/safing/portmaster/service/mgr" ) +// Apprise is the apprise notification module. type Apprise struct { mgr *mgr.Manager instance instance } +// Manager returns the module manager. func (a *Apprise) Manager() *mgr.Manager { return a.mgr } +// Start starts the module. func (a *Apprise) Start() error { return startApprise() } +// Stop stops the module. func (a *Apprise) Stop() error { return nil } @@ -276,7 +280,7 @@ func getCountryInfo(code string) geoip.CountryInfo { // } // } -// New returns a new Apprise module. +// NewApprise returns a new Apprise module. func NewApprise(instance instance) (*Observer, error) { if !appriseShimLoaded.CompareAndSwap(false, true) { return nil, errors.New("only one instance allowed") diff --git a/cmds/observation-hub/main.go b/cmds/observation-hub/main.go index ecc0cd55e..0a96df6e1 100644 --- a/cmds/observation-hub/main.go +++ b/cmds/observation-hub/main.go @@ -73,6 +73,13 @@ func main() { } instance.AddModule(observer) + _, err = NewApprise(instance) + if err != nil { + fmt.Printf("error creating an instance: create apprise module: %s\n", err) + os.Exit(2) + } + instance.AddModule(observer) + // Execute command line operation, if requested or available. switch { case !execCmdLine: diff --git a/cmds/observation-hub/observe.go b/cmds/observation-hub/observe.go index 354de200f..ff2d16976 100644 --- a/cmds/observation-hub/observe.go +++ b/cmds/observation-hub/observe.go @@ -20,19 +20,23 @@ import ( "github.com/safing/portmaster/spn/navigator" ) +// Observer is the network observer module. type Observer struct { mgr *mgr.Manager instance instance } +// Manager returns the module manager. 
func (o *Observer) Manager() *mgr.Manager { return o.mgr } +// Start starts the module. func (o *Observer) Start() error { return startObserver() } +// Stop stops the module. func (o *Observer) Stop() error { return nil } diff --git a/service/compat/module.go b/service/compat/module.go index cc40e29aa..5ac97b511 100644 --- a/service/compat/module.go +++ b/service/compat/module.go @@ -13,6 +13,7 @@ import ( "github.com/safing/portmaster/service/resolver" ) +// Compat is the compatibility check module. type Compat struct { mgr *mgr.Manager instance instance @@ -23,10 +24,12 @@ type Compat struct { states *mgr.StateMgr } +// Manager returns the module manager. func (u *Compat) Manager() *mgr.Manager { return u.mgr } +// States returns the module state manager. func (u *Compat) States() *mgr.StateMgr { return u.states } diff --git a/service/compat/notify.go b/service/compat/notify.go index 3cb64bf5b..86999719c 100644 --- a/service/compat/notify.go +++ b/service/compat/notify.go @@ -111,7 +111,7 @@ func systemCompatOrManualDNSIssue() *systemIssue { return manualDNSSetupRequired } -func (issue *systemIssue) notify(err error) { +func (issue *systemIssue) notify(err error) { //nolint // TODO: Should we use the error? systemIssueNotificationLock.Lock() defer systemIssueNotificationLock.Unlock() diff --git a/service/core/base/module.go b/service/core/base/module.go index 1082e2bfd..664920152 100644 --- a/service/core/base/module.go +++ b/service/core/base/module.go @@ -7,15 +7,18 @@ import ( "github.com/safing/portmaster/service/mgr" ) +// Base is the base module. type Base struct { mgr *mgr.Manager instance instance } +// Manager returns the module manager. func (b *Base) Manager() *mgr.Manager { return b.mgr } +// Start starts the module. func (b *Base) Start() error { startProfiling() registerLogCleaner() @@ -23,6 +26,7 @@ func (b *Base) Start() error { return nil } +// Stop stops the module. func (b *Base) Stop() error { return nil } diff --git a/service/core/core.go b/service/core/core.go index 33025565c..e14789d08 100644 --- a/service/core/core.go +++ b/service/core/core.go @@ -17,11 +17,7 @@ import ( _ "github.com/safing/portmaster/service/ui" ) -const ( - eventShutdown = "shutdown" - eventRestart = "restart" -) - +// Core is the core service module. type Core struct { m *mgr.Manager instance instance @@ -30,14 +26,17 @@ type Core struct { EventRestart *mgr.EventMgr[struct{}] } +// Manager returns the manager. func (c *Core) Manager() *mgr.Manager { return c.m } +// Start starts the module. func (c *Core) Start() error { return start() } +// Stop stops the module. func (c *Core) Stop() error { return nil } diff --git a/service/firewall/interception/module.go b/service/firewall/interception/module.go index 72a66abe4..072eb3b5e 100644 --- a/service/firewall/interception/module.go +++ b/service/firewall/interception/module.go @@ -10,19 +10,23 @@ import ( "github.com/safing/portmaster/service/network/packet" ) +// Interception is the packet interception module. type Interception struct { mgr *mgr.Manager instance instance } +// Manager returns the module manager. func (i *Interception) Manager() *mgr.Manager { return i.mgr } +// Start starts the module. func (i *Interception) Start() error { return start() } +// Stop stops the module. 
func (i *Interception) Stop() error { return stop() } diff --git a/service/firewall/module.go b/service/firewall/module.go index 1f716cd08..131d4cacb 100644 --- a/service/firewall/module.go +++ b/service/firewall/module.go @@ -31,7 +31,6 @@ func (ss *stringSliceFlag) Set(value string) error { return nil } -// module *modules.Module var allowedClients stringSliceFlag type Firewall struct { diff --git a/service/instance.go b/service/instance.go index fd39a0ce4..d55dccd8b 100644 --- a/service/instance.go +++ b/service/instance.go @@ -101,7 +101,7 @@ type Instance struct { } // New returns a new Portmaster service instance. -func New(svcCfg *ServiceConfig) (*Instance, error) { +func New(svcCfg *ServiceConfig) (*Instance, error) { //nolint:maintidx // Create instance to pass it to modules. instance := &Instance{} instance.ctx, instance.cancelCtx = context.WithCancel(context.Background()) @@ -580,7 +580,8 @@ func (i *Instance) Stop() error { // RestartExitCode will instruct portmaster-start to restart the process immediately, potentially with a new version. const RestartExitCode = 23 -// Shutdown asynchronously stops the instance. +// Restart asynchronously restarts the instance. +// This only works if the underlying system/process supports this. func (i *Instance) Restart() { // Send a restart event, give it 10ms extra to propagate. i.core.EventRestart.Submit(struct{}{}) diff --git a/service/mgr/states.go b/service/mgr/states.go index 9049bb298..c80d6771a 100644 --- a/service/mgr/states.go +++ b/service/mgr/states.go @@ -37,7 +37,7 @@ type State struct { // Optional. Type StateType - // Time is the time when the state was created or the originating incident occured. + // Time is the time when the state was created or the originating incident occurred. // Optional, will be set to current time if not set. Time time.Time diff --git a/service/mgr/workermgr_test.go b/service/mgr/workermgr_test.go index 758d14499..b57447078 100644 --- a/service/mgr/workermgr_test.go +++ b/service/mgr/workermgr_test.go @@ -7,6 +7,8 @@ import ( ) func TestWorkerMgrDelay(t *testing.T) { + t.Parallel() + m := New("DelayTest") value := atomic.Bool{} @@ -21,7 +23,7 @@ func TestWorkerMgrDelay(t *testing.T) { // Check if value is set after 1 second and not before or after. 
iterations := 0 for !value.Load() { - iterations += 1 + iterations++ time.Sleep(10 * time.Millisecond) } @@ -32,6 +34,8 @@ func TestWorkerMgrDelay(t *testing.T) { } func TestWorkerMgrRepeat(t *testing.T) { + t.Parallel() + m := New("RepeatTest") value := atomic.Bool{} @@ -47,7 +51,7 @@ func TestWorkerMgrRepeat(t *testing.T) { for range 10 { iterations := 0 for !value.Load() { - iterations += 1 + iterations++ time.Sleep(10 * time.Millisecond) } @@ -61,7 +65,9 @@ func TestWorkerMgrRepeat(t *testing.T) { } } -func TestWorkerMgrDelayAndRepeat(t *testing.T) { +func TestWorkerMgrDelayAndRepeat(t *testing.T) { //nolint:dupl + t.Parallel() + m := New("DelayAndRepeatTest") value := atomic.Bool{} @@ -75,7 +81,7 @@ func TestWorkerMgrDelayAndRepeat(t *testing.T) { iterations := 0 for !value.Load() { - iterations += 1 + iterations++ time.Sleep(10 * time.Millisecond) } @@ -91,7 +97,7 @@ func TestWorkerMgrDelayAndRepeat(t *testing.T) { for range 10 { iterations = 0 for !value.Load() { - iterations += 1 + iterations++ time.Sleep(10 * time.Millisecond) } @@ -105,7 +111,9 @@ func TestWorkerMgrDelayAndRepeat(t *testing.T) { } } -func TestWorkerMgrRepeatAndDelay(t *testing.T) { +func TestWorkerMgrRepeatAndDelay(t *testing.T) { //nolint:dupl + t.Parallel() + m := New("RepeatAndDelayTest") value := atomic.Bool{} @@ -119,7 +127,7 @@ func TestWorkerMgrRepeatAndDelay(t *testing.T) { iterations := 0 for !value.Load() { - iterations += 1 + iterations++ time.Sleep(10 * time.Millisecond) } @@ -134,7 +142,7 @@ func TestWorkerMgrRepeatAndDelay(t *testing.T) { for range 10 { iterations := 0 for !value.Load() { - iterations += 1 + iterations++ time.Sleep(10 * time.Millisecond) } diff --git a/service/netenv/init_test.go b/service/netenv/init_test.go index 17ef12403..b747111b2 100644 --- a/service/netenv/init_test.go +++ b/service/netenv/init_test.go @@ -91,7 +91,7 @@ func runTest(m *testing.M) error { _, err = New(stub) if err != nil { - return fmt.Errorf("failed to initialize module %s", err) + return fmt.Errorf("failed to initialize module %w", err) } m.Run() diff --git a/service/network/connection_handler.go b/service/network/connection_handler.go index 2d2613260..3293c608d 100644 --- a/service/network/connection_handler.go +++ b/service/network/connection_handler.go @@ -4,10 +4,11 @@ import ( "context" "time" + "github.com/tevino/abool" + "github.com/safing/portmaster/base/log" "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network/packet" - "github.com/tevino/abool" ) // SetFirewallHandler sets the firewall handler for this link, and starts a diff --git a/service/resolver/main.go b/service/resolver/main.go index f708d725c..8a43d12be 100644 --- a/service/resolver/main.go +++ b/service/resolver/main.go @@ -21,7 +21,8 @@ import ( "github.com/safing/portmaster/service/netenv" ) -type ResolverModule struct { +// ResolverModule is the DNS resolver module. +type ResolverModule struct { //nolint mgr *mgr.Manager instance instance @@ -31,18 +32,22 @@ type ResolverModule struct { states *mgr.StateMgr } +// Manager returns the module manager. func (rm *ResolverModule) Manager() *mgr.Manager { return rm.mgr } +// States returns the module state manager. func (rm *ResolverModule) States() *mgr.StateMgr { return rm.states } +// Start starts the module. func (rm *ResolverModule) Start() error { return start() } +// Stop stops the module. 
func (rm *ResolverModule) Stop() error { return nil } @@ -109,8 +114,7 @@ func start() error { module.instance.NetEnv().EventNetworkChange.AddCallback( "check failing resolvers", func(wc *mgr.WorkerCtx, _ struct{}) (bool, error) { - checkFailingResolvers(wc) - return false, nil + return false, checkFailingResolvers(wc) }) module.suggestUsingStaleCacheTask = module.mgr.NewWorkerMgr("suggest using stale cache", suggestUsingStaleCacheTask, nil) diff --git a/service/resolver/main_test.go b/service/resolver/main_test.go index 99cb7b05b..4efc8eb89 100644 --- a/service/resolver/main_test.go +++ b/service/resolver/main_test.go @@ -58,7 +58,7 @@ func (stub *testInstance) Shutdown() {} func (stub *testInstance) SetCmdLineOperation(f func() error) {} -func (i *testInstance) GetEventSPNConnected() *mgr.EventMgr[struct{}] { +func (stub *testInstance) GetEventSPNConnected() *mgr.EventMgr[struct{}] { return mgr.NewEventMgr[struct{}]("spn connect", nil) } diff --git a/service/updates/main.go b/service/updates/main.go index 279529c97..bb942993a 100644 --- a/service/updates/main.go +++ b/service/updates/main.go @@ -43,13 +43,11 @@ const ( ) var ( - // module *modules.Module registry *updater.ResourceRegistry userAgentFromFlag string updateServerFromFlag string - // updateTask *modules.Task updateASAP bool disableTaskSchedule bool diff --git a/spn/captain/hooks.go b/spn/captain/hooks.go index 119f6d811..d21bd67b0 100644 --- a/spn/captain/hooks.go +++ b/spn/captain/hooks.go @@ -32,7 +32,7 @@ func handleCraneUpdate(crane *docks.Crane) { func updateConnectionStatus() { // Delay updating status for a better chance to combine multiple changes. - module.maintainPublicStatus.Delay(maintainStatusUpdateDelay) + module.statusUpdater.Delay(maintainStatusUpdateDelay) // Check if we lost all connections and trigger a pending restart if we did. for _, crane := range docks.GetAllAssignedCranes() { diff --git a/spn/captain/intel.go b/spn/captain/intel.go index f6bbe4fb8..c71ce663c 100644 --- a/spn/captain/intel.go +++ b/spn/captain/intel.go @@ -34,7 +34,7 @@ func registerIntelUpdateHook() error { return nil } -func updateSPNIntel(ctx context.Context, _ interface{}) (err error) { +func updateSPNIntel(_ context.Context, _ interface{}) (err error) { intelResourceUpdateLock.Lock() defer intelResourceUpdateLock.Unlock() diff --git a/spn/captain/module.go b/spn/captain/module.go index b25040624..1f82d7632 100644 --- a/spn/captain/module.go +++ b/spn/captain/module.go @@ -23,38 +23,44 @@ import ( "github.com/safing/portmaster/spn/ships" ) -const controlledFailureExitCode = 24 - // SPNConnectedEvent is the name of the event that is fired when the SPN has connected and is ready. const SPNConnectedEvent = "spn connect" +// Captain is the main module of the SPN. type Captain struct { mgr *mgr.Manager instance instance - healthCheckTicker *mgr.SleepyTicker - maintainPublicStatus *mgr.WorkerMgr + healthCheckTicker *mgr.SleepyTicker + + publicIdentityUpdater *mgr.WorkerMgr + statusUpdater *mgr.WorkerMgr states *mgr.StateMgr EventSPNConnected *mgr.EventMgr[struct{}] } +// Manager returns the module manager. func (c *Captain) Manager() *mgr.Manager { return c.mgr } +// States returns the module states. func (c *Captain) States() *mgr.StateMgr { return c.states } +// Start starts the module. func (c *Captain) Start() error { return start() } +// Stop stops the module. func (c *Captain) Stop() error { return stop() } +// SetSleep sets the sleep mode of the module. 
func (c *Captain) SetSleep(enabled bool) { if c.healthCheckTicker != nil { c.healthCheckTicker.SetSleep(enabled) @@ -225,9 +231,11 @@ func New(instance instance) (*Captain, error) { mgr: m, instance: instance, - states: mgr.NewStateMgr(m), - EventSPNConnected: mgr.NewEventMgr[struct{}](SPNConnectedEvent, m), - maintainPublicStatus: m.NewWorkerMgr("maintain public status", maintainPublicStatus, nil), + states: mgr.NewStateMgr(m), + EventSPNConnected: mgr.NewEventMgr[struct{}](SPNConnectedEvent, m), + + publicIdentityUpdater: m.NewWorkerMgr("maintain public identity", maintainPublicIdentity, nil), + statusUpdater: m.NewWorkerMgr("maintain public status", maintainPublicStatus, nil), } if err := module.prep(); err != nil { diff --git a/spn/captain/navigation.go b/spn/captain/navigation.go index ac757616f..096a8af96 100644 --- a/spn/captain/navigation.go +++ b/spn/captain/navigation.go @@ -210,7 +210,7 @@ func connectToHomeHub(wCtx *mgr.WorkerCtx, dst *hub.Hub) error { return nil } -func optimizeNetwork(ctx *mgr.WorkerCtx) error { //, task *modules.Task) error { +func optimizeNetwork(ctx *mgr.WorkerCtx) error { if publicIdentity == nil { return nil } diff --git a/spn/captain/public.go b/spn/captain/public.go index 427ff9aac..fdb707858 100644 --- a/spn/captain/public.go +++ b/spn/captain/public.go @@ -87,11 +87,11 @@ func loadPublicIdentity() (err error) { } func prepPublicIdentityMgmt() error { - module.maintainPublicStatus.Repeat(maintainStatusInterval) + module.statusUpdater.Repeat(maintainStatusInterval) module.instance.Config().EventConfigChange.AddCallback("update public identity from config", func(wc *mgr.WorkerCtx, s struct{}) (cancel bool, err error) { - module.maintainPublicStatus.Delay(5 * time.Minute) + module.publicIdentityUpdater.Delay(5 * time.Minute) return false, nil }) return nil @@ -99,10 +99,10 @@ func prepPublicIdentityMgmt() error { // TriggerHubStatusMaintenance queues the Hub status update task to be executed. func TriggerHubStatusMaintenance() { - module.maintainPublicStatus.Go() + module.statusUpdater.Go() } -func maintainPublicIdentity(ctx *mgr.WorkerCtx) error { +func maintainPublicIdentity(_ *mgr.WorkerCtx) error { changed, err := publicIdentity.MaintainAnnouncement(nil, false) if err != nil { return fmt.Errorf("failed to maintain announcement: %w", err) diff --git a/spn/docks/hub_import.go b/spn/docks/hub_import.go index ff2981337..c9295c641 100644 --- a/spn/docks/hub_import.go +++ b/spn/docks/hub_import.go @@ -33,7 +33,7 @@ func ImportAndVerifyHubInfo(ctx context.Context, hubID string, announcementData, var hubKnown, hubChanged bool if announcementData != nil { hubFromMsg, known, changed, err := hub.ApplyAnnouncement(nil, announcementData, mapName, scope, false) - if err != nil && firstErr == nil { + if err != nil { firstErr = terminal.ErrInternalError.With("failed to apply announcement: %w", err) } if known { diff --git a/spn/docks/module.go b/spn/docks/module.go index 4ee6242f4..ceb97a5a3 100644 --- a/spn/docks/module.go +++ b/spn/docks/module.go @@ -12,19 +12,23 @@ import ( _ "github.com/safing/portmaster/spn/access" ) +// Docks handles connections to other network participants. type Docks struct { mgr *mgr.Manager instance instance } +// Manager returns the module manager. func (d *Docks) Manager() *mgr.Manager { return d.mgr } +// Start starts the module. func (d *Docks) Start() error { return start() } +// Stop stops the module. 
func (d *Docks) Stop() error { return stopAllCranes() } diff --git a/spn/docks/module_test.go b/spn/docks/module_test.go index e8cf6cd0c..47e34b436 100644 --- a/spn/docks/module_test.go +++ b/spn/docks/module_test.go @@ -47,7 +47,7 @@ func (stub *testInstance) Stopping() bool { func (stub *testInstance) SetCmdLineOperation(f func() error) {} func runTest(m *testing.M) error { - log.Start() + _ = log.Start() ds, err := config.InitializeUnitTestDataroot("test-docks") if err != nil { @@ -63,77 +63,77 @@ func runTest(m *testing.M) error { // Init instance.db, err = dbmodule.New(instance) if err != nil { - return fmt.Errorf("failed to create database module: %w\n", err) + return fmt.Errorf("failed to create database module: %w", err) } instance.config, err = config.New(instance) if err != nil { - return fmt.Errorf("failed to create config module: %w\n", err) + return fmt.Errorf("failed to create config module: %w", err) } instance.metrics, err = metrics.New(instance) if err != nil { - return fmt.Errorf("failed to create metrics module: %w\n", err) + return fmt.Errorf("failed to create metrics module: %w", err) } instance.rng, err = rng.New(instance) if err != nil { - return fmt.Errorf("failed to create rng module: %w\n", err) + return fmt.Errorf("failed to create rng module: %w", err) } instance.base, err = base.New(instance) if err != nil { - return fmt.Errorf("failed to create base module: %w\n", err) + return fmt.Errorf("failed to create base module: %w", err) } instance.access, err = access.New(instance) if err != nil { - return fmt.Errorf("failed to create access module: %w\n", err) + return fmt.Errorf("failed to create access module: %w", err) } instance.terminal, err = terminal.New(instance) if err != nil { - return fmt.Errorf("failed to create terminal module: %w\n", err) + return fmt.Errorf("failed to create terminal module: %w", err) } instance.cabin, err = cabin.New(instance) if err != nil { - return fmt.Errorf("failed to create cabin module: %w\n", err) + return fmt.Errorf("failed to create cabin module: %w", err) } module, err = New(instance) if err != nil { - return fmt.Errorf("failed to create docks module: %w\n", err) + return fmt.Errorf("failed to create docks module: %w", err) } // Start err = instance.db.Start() if err != nil { - return fmt.Errorf("failed to start db module: %w\n", err) + return fmt.Errorf("failed to start db module: %w", err) } err = instance.config.Start() if err != nil { - return fmt.Errorf("failed to start config module: %w\n", err) + return fmt.Errorf("failed to start config module: %w", err) } err = instance.metrics.Start() if err != nil { - return fmt.Errorf("failed to start metrics module: %w\n", err) + return fmt.Errorf("failed to start metrics module: %w", err) } err = instance.rng.Start() if err != nil { - return fmt.Errorf("failed to start rng module: %w\n", err) + return fmt.Errorf("failed to start rng module: %w", err) } err = instance.base.Start() if err != nil { - return fmt.Errorf("failed to start base module: %w\n", err) + return fmt.Errorf("failed to start base module: %w", err) } err = instance.access.Start() if err != nil { - return fmt.Errorf("failed to start access module: %w\n", err) + return fmt.Errorf("failed to start access module: %w", err) } err = instance.terminal.Start() if err != nil { - return fmt.Errorf("failed to start terminal module: %w\n", err) + return fmt.Errorf("failed to start terminal module: %w", err) } err = instance.cabin.Start() if err != nil { - return fmt.Errorf("failed to start cabin module: %w\n", err) + 
return fmt.Errorf("failed to start cabin module: %w", err) } err = module.Start() if err != nil { - return fmt.Errorf("failed to start docks module: %w\n", err) + return fmt.Errorf("failed to start docks module: %w", err) } m.Run() diff --git a/spn/instance.go b/spn/instance.go index 842ec7b5c..69114d385 100644 --- a/spn/instance.go +++ b/spn/instance.go @@ -380,7 +380,8 @@ func (i *Instance) Stop() error { // RestartExitCode will instruct portmaster-start to restart the process immediately, potentially with a new version. const RestartExitCode = 23 -// Shutdown asynchronously stops the instance. +// Restart asynchronously restarts the instance. +// This only works if the underlying system/process supports this. func (i *Instance) Restart() { // Send a restart event, give it 10ms extra to propagate. i.core.EventRestart.Submit(struct{}{}) diff --git a/spn/sluice/packet_listener.go b/spn/sluice/packet_listener.go index b3c1f026d..a20fcefff 100644 --- a/spn/sluice/packet_listener.go +++ b/spn/sluice/packet_listener.go @@ -7,8 +7,9 @@ import ( "sync/atomic" "time" - "github.com/safing/portmaster/service/mgr" "github.com/tevino/abool" + + "github.com/safing/portmaster/service/mgr" ) // PacketListener is a listener for packet based protocols. diff --git a/spn/sluice/udp_listener.go b/spn/sluice/udp_listener.go index 31f83e077..2f0ab03f2 100644 --- a/spn/sluice/udp_listener.go +++ b/spn/sluice/udp_listener.go @@ -8,10 +8,11 @@ import ( "sync/atomic" "time" - "github.com/safing/portmaster/service/mgr" "github.com/tevino/abool" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + + "github.com/safing/portmaster/service/mgr" ) const onWindows = runtime.GOOS == "windows" diff --git a/spn/terminal/module.go b/spn/terminal/module.go index 46424a5ac..a0867f996 100644 --- a/spn/terminal/module.go +++ b/spn/terminal/module.go @@ -12,19 +12,23 @@ import ( "github.com/safing/portmaster/spn/unit" ) -type TerminalModule struct { +// TerminalModule is the command multiplexing module. +type TerminalModule struct { //nolint:golint mgr *mgr.Manager instance instance } +// Manager returns the module manager. func (s *TerminalModule) Manager() *mgr.Manager { return s.mgr } +// Start starts the module. func (s *TerminalModule) Start() error { return start() } +// Stop stops the module. 
func (s *TerminalModule) Stop() error { return nil } diff --git a/spn/terminal/module_test.go b/spn/terminal/module_test.go index 93ebeafa9..d7e6f1d49 100644 --- a/spn/terminal/module_test.go +++ b/spn/terminal/module_test.go @@ -53,62 +53,61 @@ func runTest(m *testing.M) error { instance := &testInstance{} instance.db, err = dbmodule.New(instance) if err != nil { - return fmt.Errorf("failed to create database module: %w\n", err) + return fmt.Errorf("failed to create database module: %w", err) } instance.config, err = config.New(instance) if err != nil { - return fmt.Errorf("failed to create config module: %w\n", err) + return fmt.Errorf("failed to create config module: %w", err) } instance.metrics, err = metrics.New(instance) if err != nil { - return fmt.Errorf("failed to create metrics module: %w\n", err) + return fmt.Errorf("failed to create metrics module: %w", err) } instance.rng, err = rng.New(instance) if err != nil { - return fmt.Errorf("failed to create rng module: %w\n", err) + return fmt.Errorf("failed to create rng module: %w", err) } instance.base, err = base.New(instance) if err != nil { - return fmt.Errorf("failed to create base module: %w\n", err) + return fmt.Errorf("failed to create base module: %w", err) } instance.cabin, err = cabin.New(instance) if err != nil { - return fmt.Errorf("failed to create cabin module: %w\n", err) + return fmt.Errorf("failed to create cabin module: %w", err) } _, err = New(instance) if err != nil { - fmt.Printf("failed to create module: %s\n", err) - os.Exit(0) + return fmt.Errorf("failed to create module: %w", err) } // Start err = instance.db.Start() if err != nil { - return fmt.Errorf("failed to start db module: %w\n", err) + return fmt.Errorf("failed to start db module: %w", err) } err = instance.config.Start() if err != nil { - return fmt.Errorf("failed to start config module: %w\n", err) + return fmt.Errorf("failed to start config module: %w", err) } err = instance.metrics.Start() if err != nil { - return fmt.Errorf("failed to start metrics module: %w\n", err) + return fmt.Errorf("failed to start metrics module: %w", err) } err = instance.rng.Start() if err != nil { - return fmt.Errorf("failed to start rng module: %w\n", err) + return fmt.Errorf("failed to start rng module: %w", err) } err = instance.base.Start() if err != nil { - return fmt.Errorf("failed to start base module: %w\n", err) + return fmt.Errorf("failed to start base module: %w", err) } err = instance.cabin.Start() if err != nil { - return fmt.Errorf("failed to start cabin module: %w\n", err) + return fmt.Errorf("failed to start cabin module: %w", err) } err = module.Start() if err != nil { - return fmt.Errorf("failed to start docks module: %w\n", err) + return fmt.Errorf("failed to start terminal module: %w", err) } m.Run() diff --git a/spn/unit/scheduler.go b/spn/unit/scheduler.go index 9027241fb..fc49a0dfb 100644 --- a/spn/unit/scheduler.go +++ b/spn/unit/scheduler.go @@ -7,8 +7,9 @@ import ( "sync/atomic" "time" - "github.com/safing/portmaster/service/mgr" "github.com/tevino/abool" + + "github.com/safing/portmaster/service/mgr" ) const ( diff --git a/spn/unit/unit_test.go b/spn/unit/unit_test.go index a77ea899e..636f1231f 100644 --- a/spn/unit/unit_test.go +++ b/spn/unit/unit_test.go @@ -8,8 +8,9 @@ import ( "testing" "time" - "github.com/safing/portmaster/service/mgr" "github.com/stretchr/testify/assert" + + "github.com/safing/portmaster/service/mgr" ) func TestUnit(t *testing.T) { //nolint:paralleltest From 45b50fe6cf9ec37bcf593e2a2725f135cbf57efd Mon Sep 17 00:00:00 
2001 From: Daniel Date: Fri, 9 Aug 2024 14:54:05 +0200 Subject: [PATCH 55/56] Fix interception module on windows --- .../interception/interception_windows.go | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/service/firewall/interception/interception_windows.go b/service/firewall/interception/interception_windows.go index bd36ffa1a..cb97376b7 100644 --- a/service/firewall/interception/interception_windows.go +++ b/service/firewall/interception/interception_windows.go @@ -1,13 +1,13 @@ package interception import ( - "context" "fmt" "time" "github.com/safing/portmaster/base/log" kext1 "github.com/safing/portmaster/service/firewall/interception/windowskext" kext2 "github.com/safing/portmaster/service/firewall/interception/windowskext2" + "github.com/safing/portmaster/service/mgr" "github.com/safing/portmaster/service/network" "github.com/safing/portmaster/service/network/packet" "github.com/safing/portmaster/service/updates" @@ -46,25 +46,25 @@ func startInterception(packets chan packet.Packet) error { kext1.SetKextService(kext2.GetKextServiceHandle(), kextFile.Path()) // Start packet handler. - module.StartServiceWorker("kext packet handler", 0, func(ctx context.Context) error { - kext1.Handler(ctx, packets) + module.mgr.Go("kext packet handler", func(w *mgr.WorkerCtx) error { + kext1.Handler(w.Ctx(), packets) return nil }) // Start bandwidth stats monitor. - module.StartServiceWorker("kext bandwidth stats monitor", 0, func(ctx context.Context) error { - return kext1.BandwidthStatsWorker(ctx, 1*time.Second, BandwidthUpdates) + module.mgr.Go("kext bandwidth stats monitor", func(w *mgr.WorkerCtx) error { + return kext1.BandwidthStatsWorker(w.Ctx(), 1*time.Second, BandwidthUpdates) }) } else { // Start packet handler. - module.StartServiceWorker("kext packet handler", 0, func(ctx context.Context) error { - kext2.Handler(ctx, packets, BandwidthUpdates) + module.mgr.Go("kext packet handler", func(w *mgr.WorkerCtx) error { + kext2.Handler(w.Ctx(), packets, BandwidthUpdates) return nil }) // Start bandwidth stats monitor. - module.StartServiceWorker("kext bandwidth request worker", 0, func(ctx context.Context) error { + module.mgr.Go("kext bandwidth request worker", func(w *mgr.WorkerCtx) error { timer := time.NewTicker(1 * time.Second) defer timer.Stop() for { @@ -74,7 +74,7 @@ func startInterception(packets chan packet.Packet) error { if err != nil { return err } - case <-ctx.Done(): + case <-w.Done(): return nil } @@ -82,7 +82,7 @@ func startInterception(packets chan packet.Packet) error { }) // Start kext logging. The worker will periodically send request to the kext to send logs. 
- module.StartServiceWorker("kext log request worker", 0, func(ctx context.Context) error { + module.mgr.Go("kext log request worker", func(w *mgr.WorkerCtx) error { timer := time.NewTicker(1 * time.Second) defer timer.Stop() for { @@ -92,14 +92,14 @@ func startInterception(packets chan packet.Packet) error { if err != nil { return err } - case <-ctx.Done(): + case <-w.Done(): return nil } } }) - module.StartServiceWorker("kext clean ended connection worker", 0, func(ctx context.Context) error { + module.mgr.Go("kext clean ended connection worker", func(w *mgr.WorkerCtx) error { timer := time.NewTicker(30 * time.Second) defer timer.Stop() for { @@ -109,7 +109,7 @@ func startInterception(packets chan packet.Packet) error { if err != nil { return err } - case <-ctx.Done(): + case <-w.Done(): return nil } From b2b51ffded6865b06a99aa290b34c6dc622e4fc5 Mon Sep 17 00:00:00 2001 From: Vladimir Stoilov Date: Fri, 9 Aug 2024 18:06:20 +0300 Subject: [PATCH 56/56] Fix linter errors --- base/container/container_test.go | 8 +++--- base/database/record/meta-bench_test.go | 22 +++++++-------- base/log/output.go | 2 ++ base/rng/fullfeed_test.go | 2 +- base/rng/get_test.go | 2 +- base/updater/get.go | 2 +- base/updater/updating.go | 10 +++---- base/utils/call_limiter_test.go | 8 +++--- base/utils/onceagain_test.go | 6 ++--- base/utils/renameio/symlink_test.go | 2 +- base/utils/renameio/tempfile_linux_test.go | 4 +-- base/utils/stablepool_test.go | 24 ++++++++--------- cmds/observation-hub/apprise.go | 2 +- service/firewall/api.go | 4 +-- .../interception/ebpf/bandwidth/interface.go | 12 ++++----- .../ebpf/connection_listener/worker.go | 2 +- .../firewall/interception/ebpf/exec/exec.go | 2 +- service/netenv/location.go | 2 +- service/netquery/orm/decoder.go | 6 ++--- service/netquery/orm/encoder.go | 2 +- service/netquery/orm/query_runner.go | 2 +- service/netquery/orm/schema_builder.go | 2 +- service/netquery/query_test.go | 27 +++++++++---------- service/network/netutils/ip.go | 2 +- service/network/ports.go | 2 +- service/profile/fingerprint_test.go | 2 +- service/resolver/resolver-tcp.go | 2 +- service/resolver/resolver_test.go | 4 +-- service/status/status.go | 6 ++--- spn/access/client_test.go | 2 +- spn/access/token/pblind.go | 8 +++--- spn/access/token/pblind_test.go | 10 +++---- spn/cabin/keys.go | 2 +- spn/cabin/keys_test.go | 2 +- spn/crew/op_connect_test.go | 3 +-- spn/docks/module.go | 2 +- spn/hub/hub.go | 2 +- spn/navigator/findnearest_test.go | 4 +-- spn/navigator/findroutes_test.go | 6 ++--- spn/navigator/map_test.go | 10 +++---- spn/navigator/module.go | 2 +- spn/navigator/optimize.go | 2 +- spn/patrol/http.go | 2 +- spn/ships/connection_test.go | 5 ++-- spn/ships/http_shared.go | 5 ++-- spn/ships/testship_test.go | 2 +- spn/terminal/session_test.go | 11 ++++---- spn/unit/unit_test.go | 4 +-- 48 files changed, 123 insertions(+), 134 deletions(-) diff --git a/base/container/container_test.go b/base/container/container_test.go index bc8608b5f..b8a314057 100644 --- a/base/container/container_test.go +++ b/base/container/container_test.go @@ -29,7 +29,7 @@ func TestContainerDataHandling(t *testing.T) { c1c := c1.carbonCopy() c2 := New() - for i := 0; i < len(testData); i++ { + for range len(testData) { oneByte := make([]byte, 1) c1c.WriteToSlice(oneByte) c2.Append(oneByte) @@ -48,7 +48,7 @@ func TestContainerDataHandling(t *testing.T) { c3c = c3.carbonCopy() d5 := make([]byte, len(testData)) - for i := 0; i < len(testData); i++ { + for i := range len(testData) { c3c.WriteToSlice(d5[i : i+1]) } 
@@ -61,7 +61,7 @@ func TestContainerDataHandling(t *testing.T) { } c8 := New(testDataSplitted...) - for i := 0; i < 110; i++ { + for range 110 { c8.Prepend(nil) } c8.clean() @@ -155,7 +155,7 @@ func TestContainerBlockHandling(t *testing.T) { c1c := c1.carbonCopy() c2 := New(nil) - for i := 0; i < c1.Length(); i++ { + for range c1.Length() { oneByte := make([]byte, 1) c1c.WriteToSlice(oneByte) c2.Append(oneByte) diff --git a/base/database/record/meta-bench_test.go b/base/database/record/meta-bench_test.go index bfcf05173..fd9d3b6b1 100644 --- a/base/database/record/meta-bench_test.go +++ b/base/database/record/meta-bench_test.go @@ -36,27 +36,27 @@ var testMeta = &Meta{ } func BenchmarkAllocateBytes(b *testing.B) { - for i := 0; i < b.N; i++ { + for range b.N { _ = make([]byte, 33) } } func BenchmarkAllocateStruct1(b *testing.B) { - for i := 0; i < b.N; i++ { + for range b.N { var newMeta Meta _ = newMeta } } func BenchmarkAllocateStruct2(b *testing.B) { - for i := 0; i < b.N; i++ { + for range b.N { _ = Meta{} } } func BenchmarkMetaSerializeContainer(b *testing.B) { // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { c := container.New() c.AppendNumber(uint64(testMeta.Created)) c.AppendNumber(uint64(testMeta.Modified)) @@ -98,7 +98,7 @@ func BenchmarkMetaUnserializeContainer(b *testing.B) { b.ResetTimer() // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { var newMeta Meta var err error var num uint64 @@ -152,7 +152,7 @@ func BenchmarkMetaUnserializeContainer(b *testing.B) { func BenchmarkMetaSerializeVarInt(b *testing.B) { // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { encoded := make([]byte, 33) offset := 0 data := varint.Pack64(uint64(testMeta.Created)) @@ -231,7 +231,7 @@ func BenchmarkMetaUnserializeVarInt(b *testing.B) { b.ResetTimer() // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { var newMeta Meta offset = 0 @@ -284,7 +284,7 @@ func BenchmarkMetaUnserializeVarInt(b *testing.B) { } func BenchmarkMetaSerializeWithCodegen(b *testing.B) { - for i := 0; i < b.N; i++ { + for range b.N { _, err := testMeta.GenCodeMarshal(nil) if err != nil { b.Errorf("failed to serialize with codegen: %s", err) @@ -305,7 +305,7 @@ func BenchmarkMetaUnserializeWithCodegen(b *testing.B) { b.ResetTimer() // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { var newMeta Meta _, err := newMeta.GenCodeUnmarshal(encodedData) if err != nil { @@ -316,7 +316,7 @@ func BenchmarkMetaUnserializeWithCodegen(b *testing.B) { } func BenchmarkMetaSerializeWithDSDJSON(b *testing.B) { - for i := 0; i < b.N; i++ { + for range b.N { _, err := dsd.Dump(testMeta, dsd.JSON) if err != nil { b.Errorf("failed to serialize with DSD/JSON: %s", err) @@ -337,7 +337,7 @@ func BenchmarkMetaUnserializeWithDSDJSON(b *testing.B) { b.ResetTimer() // Start benchmark - for i := 0; i < b.N; i++ { + for range b.N { var newMeta Meta _, err := dsd.Load(encodedData, &newMeta) if err != nil { diff --git a/base/log/output.go b/base/log/output.go index f80fe3519..a2947dc56 100644 --- a/base/log/output.go +++ b/base/log/output.go @@ -129,6 +129,8 @@ func writerManager() { } } +// defer should be able to edit the err. So naked return is required. 
+// nolint:golint,nakedret func writer() (err error) { defer func() { // recover from panic diff --git a/base/rng/fullfeed_test.go b/base/rng/fullfeed_test.go index c3da8373f..c3188c22a 100644 --- a/base/rng/fullfeed_test.go +++ b/base/rng/fullfeed_test.go @@ -7,7 +7,7 @@ import ( func TestFullFeeder(t *testing.T) { t.Parallel() - for i := 0; i < 10; i++ { + for range 10 { go func() { rngFeeder <- []byte{0} }() diff --git a/base/rng/get_test.go b/base/rng/get_test.go index b92003747..f1afa6bf4 100644 --- a/base/rng/get_test.go +++ b/base/rng/get_test.go @@ -19,7 +19,7 @@ func TestNumberRandomness(t *testing.T) { var testSize uint64 = 10000 results := make([]uint64, int(subjects)) - for i := 0; i < int(subjects*testSize); i++ { + for range int(subjects * testSize) { n, err := Number(subjects - 1) if err != nil { t.Fatal(err) diff --git a/base/updater/get.go b/base/updater/get.go index eb09ba98f..d50d28b36 100644 --- a/base/updater/get.go +++ b/base/updater/get.go @@ -59,7 +59,7 @@ func (reg *ResourceRegistry) GetFile(identifier string) (*File, error) { // download file log.Tracef("%s: starting download of %s", reg.Name, file.versionedPath) client := &http.Client{} - for tries := 0; tries < 5; tries++ { + for tries := range 5 { err = reg.fetchFile(context.TODO(), client, file.version, tries) if err != nil { log.Tracef("%s: failed to download %s: %s, retrying (%d)", reg.Name, file.versionedPath, err, tries+1) diff --git a/base/updater/updating.go b/base/updater/updating.go index 23e3df455..cf87472e8 100644 --- a/base/updater/updating.go +++ b/base/updater/updating.go @@ -75,7 +75,7 @@ func (reg *ResourceRegistry) downloadIndex(ctx context.Context, client *http.Cli } // Download new index and signature. - for tries := 0; tries < 3; tries++ { + for tries := range 3 { // Index and signature need to be fetched together, so that they are // fetched from the same source. One source should always have a matching // index and signature. Backup sources may be behind a little. @@ -219,7 +219,7 @@ func (reg *ResourceRegistry) DownloadUpdates(ctx context.Context, includeManual rv.resource.Identifier, rv.VersionNumber, ) var err error - for tries := 0; tries < 3; tries++ { + for tries := range 3 { err = reg.fetchFile(ctx, client, rv, tries) if err == nil { // Update resource version state. @@ -249,7 +249,7 @@ func (reg *ResourceRegistry) DownloadUpdates(ctx context.Context, includeManual for _, rv := range missingSigs { var err error - for tries := 0; tries < 3; tries++ { + for tries := range 3 { err = reg.fetchMissingSig(ctx, client, rv, tries) if err == nil { // Update resource version state. @@ -338,10 +338,10 @@ func (reg *ResourceRegistry) GetPendingDownloads(manual, auto bool) (resources, }() } - slices.SortFunc[[]*ResourceVersion, *ResourceVersion](toUpdate, func(a, b *ResourceVersion) int { + slices.SortFunc(toUpdate, func(a, b *ResourceVersion) int { return strings.Compare(a.resource.Identifier, b.resource.Identifier) }) - slices.SortFunc[[]*ResourceVersion, *ResourceVersion](missingSigs, func(a, b *ResourceVersion) int { + slices.SortFunc(missingSigs, func(a, b *ResourceVersion) int { return strings.Compare(a.resource.Identifier, b.resource.Identifier) }) diff --git a/base/utils/call_limiter_test.go b/base/utils/call_limiter_test.go index 16bd1d5ab..3144644ed 100644 --- a/base/utils/call_limiter_test.go +++ b/base/utils/call_limiter_test.go @@ -21,9 +21,9 @@ func TestCallLimiter(t *testing.T) { // We are doing this without sleep in function, so dummy exec first to trigger first pause. 
oa.Do(func() {}) // Start - for i := 0; i < 10; i++ { + for range 10 { testWg.Add(100) - for i := 0; i < 100; i++ { + for range 100 { go func() { oa.Do(func() { if !executed.SetToIf(false, true) { @@ -48,7 +48,7 @@ func TestCallLimiter(t *testing.T) { // Choose values so that about 10 executions are expected var execs uint32 testWg.Add(200) - for i := 0; i < 200; i++ { + for range 200 { go func() { oa.Do(func() { atomic.AddUint32(&execs, 1) @@ -74,7 +74,7 @@ func TestCallLimiter(t *testing.T) { // Check if the limiter correctly handles panics. testWg.Add(100) - for i := 0; i < 100; i++ { + for range 100 { go func() { defer func() { _ = recover() diff --git a/base/utils/onceagain_test.go b/base/utils/onceagain_test.go index 5d4e4aaff..15bd6d4f3 100644 --- a/base/utils/onceagain_test.go +++ b/base/utils/onceagain_test.go @@ -17,9 +17,9 @@ func TestOnceAgain(t *testing.T) { var testWg sync.WaitGroup // One execution should gobble up the whole batch. - for i := 0; i < 10; i++ { + for range 10 { testWg.Add(100) - for i := 0; i < 100; i++ { + for range 100 { go func() { oa.Do(func() { if !executed.SetToIf(false, true) { @@ -38,7 +38,7 @@ func TestOnceAgain(t *testing.T) { // Choose values so that about 10 executions are expected var execs uint32 testWg.Add(100) - for i := 0; i < 100; i++ { + for range 100 { go func() { oa.Do(func() { atomic.AddUint32(&execs, 1) diff --git a/base/utils/renameio/symlink_test.go b/base/utils/renameio/symlink_test.go index a3a1b48db..a1603996d 100644 --- a/base/utils/renameio/symlink_test.go +++ b/base/utils/renameio/symlink_test.go @@ -25,7 +25,7 @@ func TestSymlink(t *testing.T) { t.Fatal(err) } - for i := 0; i < 2; i++ { + for range 2 { if err := Symlink("hello.txt", filepath.Join(d, "hi.txt")); err != nil { t.Fatal(err) } diff --git a/base/utils/renameio/tempfile_linux_test.go b/base/utils/renameio/tempfile_linux_test.go index 88ce025e0..a1e223059 100644 --- a/base/utils/renameio/tempfile_linux_test.go +++ b/base/utils/renameio/tempfile_linux_test.go @@ -95,9 +95,7 @@ func TestTempDir(t *testing.T) { }, } - for _, tt := range tests { - testCase := tt - + for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { t.Parallel() diff --git a/base/utils/stablepool_test.go b/base/utils/stablepool_test.go index 0f8ed7262..f32206fc3 100644 --- a/base/utils/stablepool_test.go +++ b/base/utils/stablepool_test.go @@ -27,13 +27,12 @@ func TestStablePoolRealWorld(t *testing.T) { // cnt++ // testPool.Put(cnt) // } - for i := 0; i < 100; i++ { + for range 100 { // block round testWg.Add(1) // add workers testWorkerWg.Add(100) - for j := 0; j < 100; j++ { - k := j + for j := range 100 { go func() { // wait for round to start testWg.Wait() @@ -43,7 +42,7 @@ func TestStablePoolRealWorld(t *testing.T) { // "work" time.Sleep(5 * time.Microsecond) // re-insert 99% - if k%100 > 0 { + if j%100 > 0 { testPool.Put(x) } // mark as finished @@ -62,11 +61,11 @@ func TestStablePoolRealWorld(t *testing.T) { // optimal usage test optPool := &StablePool{} - for i := 0; i < 1000; i++ { - for j := 0; j < 100; j++ { + for range 1000 { + for j := range 100 { optPool.Put(j) } - for k := 0; k < 100; k++ { + for k := range 100 { assert.Equal(t, k, optPool.Get(), "should match") } } @@ -82,12 +81,11 @@ func TestStablePoolFuzzing(t *testing.T) { var fuzzWorkerWg sync.WaitGroup // start goroutines and wait fuzzWg.Add(1) - for i := 0; i < 1000; i++ { + for i := range 1000 { fuzzWorkerWg.Add(2) - j := i go func() { fuzzWg.Wait() - fuzzPool.Put(j) + fuzzPool.Put(i) fuzzWorkerWg.Done() }() 
go func() { @@ -107,13 +105,13 @@ func TestStablePoolBreaking(t *testing.T) { // try to break it breakPool := &StablePool{} - for i := 0; i < 10; i++ { - for j := 0; j < 100; j++ { + for range 10 { + for j := range 100 { breakPool.Put(nil) breakPool.Put(j) breakPool.Put(nil) } - for k := 0; k < 100; k++ { + for k := range 100 { assert.Equal(t, k, breakPool.Get(), "should match") } } diff --git a/cmds/observation-hub/apprise.go b/cmds/observation-hub/apprise.go index c9bf0e65e..0303de415 100644 --- a/cmds/observation-hub/apprise.go +++ b/cmds/observation-hub/apprise.go @@ -151,7 +151,7 @@ handleTag: // Send notification to apprise. var err error - for i := 0; i < 3; i++ { + for range 3 { // Try three times. err = appriseNotifier.Send(appriseModule.mgr.Ctx(), &apprise.Message{ Body: buf.String(), diff --git a/service/firewall/api.go b/service/firewall/api.go index c37dbf690..134f6f749 100644 --- a/service/firewall/api.go +++ b/service/firewall/api.go @@ -98,7 +98,7 @@ func apiAuthenticator(r *http.Request, s *http.Server) (token *api.AuthToken, er // It is important that this works, retry 5 times: every 500ms for 2.5s. var retry bool - for tries := 0; tries < 5; tries++ { + for range 5 { retry, err = authenticateAPIRequest( r.Context(), &packet.Info{ @@ -161,7 +161,7 @@ func authenticateAPIRequest(ctx context.Context, pktInfo *packet.Info) (retry bo // Find parent for up to two levels, if we don't match the path. checkLevels := 2 checkLevelsLoop: - for i := 0; i < checkLevels+1; i++ { + for i := range checkLevels + 1 { // Check for eligible path. switch proc.Pid { case process.UnidentifiedProcessID, process.SystemProcessID: diff --git a/service/firewall/interception/ebpf/bandwidth/interface.go b/service/firewall/interception/ebpf/bandwidth/interface.go index 124b21a42..e27989775 100644 --- a/service/firewall/interception/ebpf/bandwidth/interface.go +++ b/service/firewall/interception/ebpf/bandwidth/interface.go @@ -182,11 +182,11 @@ func convertArrayToIP(input [4]uint32, ipv6 bool) net.IP { addressBuf := make([]byte, 4) binary.LittleEndian.PutUint32(addressBuf, input[0]) return net.IP(addressBuf) + } else { + addressBuf := make([]byte, 16) + for i := range 4 { + binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) + } + return net.IP(addressBuf) } - - addressBuf := make([]byte, 16) - for i := 0; i < 4; i++ { - binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) - } - return net.IP(addressBuf) } diff --git a/service/firewall/interception/ebpf/connection_listener/worker.go b/service/firewall/interception/ebpf/connection_listener/worker.go index e7019c3fd..0768b73d6 100644 --- a/service/firewall/interception/ebpf/connection_listener/worker.go +++ b/service/firewall/interception/ebpf/connection_listener/worker.go @@ -169,7 +169,7 @@ func convertArrayToIPv4(input [4]uint32, ipVersion packet.IPVersion) net.IP { } addressBuf := make([]byte, 16) - for i := 0; i < 4; i++ { + for i := range 4 { binary.LittleEndian.PutUint32(addressBuf[i*4:i*4+4], input[i]) } return net.IP(addressBuf) diff --git a/service/firewall/interception/ebpf/exec/exec.go b/service/firewall/interception/ebpf/exec/exec.go index 3e5433442..d09fc1302 100644 --- a/service/firewall/interception/ebpf/exec/exec.go +++ b/service/firewall/interception/ebpf/exec/exec.go @@ -202,7 +202,7 @@ func (t *Tracer) Read() (*Event, error) { if argc > arglen { argc = arglen } - for i := 0; i < argc; i++ { + for i := range argc { str := unix.ByteSliceToString(rawEvent.Argv[i][:]) if strings.TrimSpace(str) != "" { ev.Argv = 
append(ev.Argv, str) diff --git a/service/netenv/location.go b/service/netenv/location.go index 6026da33c..71cdafefa 100644 --- a/service/netenv/location.go +++ b/service/netenv/location.go @@ -418,7 +418,7 @@ nextHop: // Send ICMP packet. // Try to send three times, as this can be flaky. sendICMP: - for i := 0; i < 3; i++ { + for range 3 { _, err = conn.WriteTo(pingPacket, locationTestingIPv4Addr) if err == nil { break sendICMP diff --git a/service/netquery/orm/decoder.go b/service/netquery/orm/decoder.go index 169c5e7ef..7f022908e 100644 --- a/service/netquery/orm/decoder.go +++ b/service/netquery/orm/decoder.go @@ -101,7 +101,7 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte // create a lookup map from field name (or sqlite:"" tag) // to the field name lm := make(map[string]string) - for i := 0; i < target.NumField(); i++ { + for i := range target.NumField() { fieldType := t.Field(i) // skip unexported fields @@ -114,7 +114,7 @@ func DecodeStmt(ctx context.Context, schema *TableSchema, stmt Stmt, result inte // iterate over all columns and assign them to the correct // fields - for i := 0; i < stmt.ColumnCount(); i++ { + for i := range stmt.ColumnCount() { colName := stmt.ColumnName(i) fieldName, ok := lm[colName] if !ok { @@ -246,7 +246,7 @@ func decodeIntoMap(_ context.Context, schema *TableSchema, stmt Stmt, mp *map[st *mp = make(map[string]interface{}) } - for i := 0; i < stmt.ColumnCount(); i++ { + for i := range stmt.ColumnCount() { var x interface{} colDef := schema.GetColumnDef(stmt.ColumnName(i)) diff --git a/service/netquery/orm/encoder.go b/service/netquery/orm/encoder.go index 8aa533878..9d5bd7ae1 100644 --- a/service/netquery/orm/encoder.go +++ b/service/netquery/orm/encoder.go @@ -33,7 +33,7 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode res := make(map[string]interface{}, val.NumField()) - for i := 0; i < val.NumField(); i++ { + for i := range val.NumField() { fieldType := val.Type().Field(i) field := val.Field(i) diff --git a/service/netquery/orm/query_runner.go b/service/netquery/orm/query_runner.go index 135a29f61..6eff986ce 100644 --- a/service/netquery/orm/query_runner.go +++ b/service/netquery/orm/query_runner.go @@ -145,7 +145,7 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil { resultDump := make(map[string]any) - for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ { + for colIdx := range stmt.ColumnCount() { name := stmt.ColumnName(colIdx) switch stmt.ColumnType(colIdx) { //nolint:exhaustive // TODO: handle type BLOB? 
diff --git a/service/netquery/orm/schema_builder.go b/service/netquery/orm/schema_builder.go index 90805c80e..89381cb82 100644 --- a/service/netquery/orm/schema_builder.go +++ b/service/netquery/orm/schema_builder.go @@ -142,7 +142,7 @@ func GenerateTableSchema(name string, d interface{}) (*TableSchema, error) { return nil, fmt.Errorf("%w, got %T", errStructExpected, d) } - for i := 0; i < val.NumField(); i++ { + for i := range val.NumField() { fieldType := val.Type().Field(i) if !fieldType.IsExported() { continue diff --git a/service/netquery/query_test.go b/service/netquery/query_test.go index 0582aacfd..16a9aab71 100644 --- a/service/netquery/query_test.go +++ b/service/netquery/query_test.go @@ -92,18 +92,17 @@ func TestUnmarshalQuery(t *testing.T) { //nolint:tparallel } for _, testCase := range cases { //nolint:paralleltest - c := testCase - t.Run(c.Name, func(t *testing.T) { + t.Run(testCase.Name, func(t *testing.T) { var q Query - err := json.Unmarshal([]byte(c.Input), &q) + err := json.Unmarshal([]byte(testCase.Input), &q) - if c.Error != nil { + if testCase.Error != nil { if assert.Error(t, err) { - assert.Equal(t, c.Error.Error(), err.Error()) + assert.Equal(t, testCase.Error.Error(), err.Error()) } } else { require.NoError(t, err) - assert.Equal(t, c.Expected, q) + assert.Equal(t, testCase.Expected, q) } }) } @@ -230,20 +229,18 @@ func TestQueryBuilder(t *testing.T) { //nolint:tparallel tbl, err := orm.GenerateTableSchema("connections", Conn{}) require.NoError(t, err) - for idx, testCase := range cases { //nolint:paralleltest - cID := idx - c := testCase - t.Run(c.N, func(t *testing.T) { - str, params, err := c.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig) + for cID, testCase := range cases { //nolint:paralleltest + t.Run(testCase.N, func(t *testing.T) { + str, params, err := testCase.Q.toSQLWhereClause(context.TODO(), "", tbl, orm.DefaultEncodeConfig) - if c.E != nil { + if testCase.E != nil { if assert.Error(t, err) { - assert.Equal(t, c.E.Error(), err.Error(), "test case %d", cID) + assert.Equal(t, testCase.E.Error(), err.Error(), "test case %d", cID) } } else { require.NoError(t, err, "test case %d", cID) - assert.Equal(t, c.P, params, "test case %d", cID) - assert.Equal(t, c.R, str, "test case %d", cID) + assert.Equal(t, testCase.P, params, "test case %d", cID) + assert.Equal(t, testCase.R, str, "test case %d", cID) } }) } diff --git a/service/network/netutils/ip.go b/service/network/netutils/ip.go index af316af21..a9d6d0ccb 100644 --- a/service/network/netutils/ip.go +++ b/service/network/netutils/ip.go @@ -157,7 +157,7 @@ func GetBroadcastAddress(ip net.IP, netMask net.IPMask) net.IP { // Merge to broadcast address n := len(ip) broadcastAddress := make(net.IP, n) - for i := 0; i < n; i++ { + for i := range n { broadcastAddress[i] = ip[i] | ^mask[i] } return broadcastAddress diff --git a/service/network/ports.go b/service/network/ports.go index 11c322413..d989b0604 100644 --- a/service/network/ports.go +++ b/service/network/ports.go @@ -13,7 +13,7 @@ func GetUnusedLocalPort(protocol uint8) (port uint16, ok bool) { // Try up to 1000 times to find an unused port. 
nextPort: - for i := 0; i < tries; i++ { + for i := range tries { // Generate random port between 10000 and 65535 rN, err := rng.Number(55535) if err != nil { diff --git a/service/profile/fingerprint_test.go b/service/profile/fingerprint_test.go index 3cbdb5512..2095cbf40 100644 --- a/service/profile/fingerprint_test.go +++ b/service/profile/fingerprint_test.go @@ -40,7 +40,7 @@ func TestDeriveProfileID(t *testing.T) { rnd := rand.New(rand.NewSource(time.Now().UnixNano())) //nolint:gosec // Test 100 times. - for i := 0; i < 100; i++ { + for range 100 { // Shuffle fingerprints. rnd.Shuffle(len(fps), func(i, j int) { fps[i], fps[j] = fps[j], fps[i] diff --git a/service/resolver/resolver-tcp.go b/service/resolver/resolver-tcp.go index 261f0e5bc..6d02e1643 100644 --- a/service/resolver/resolver-tcp.go +++ b/service/resolver/resolver-tcp.go @@ -358,7 +358,7 @@ func (trc *tcpResolverConn) handler(workerCtx *mgr.WorkerCtx) error { // assignUniqueID makes sure that ID assigned to msg is unique. func (trc *tcpResolverConn) assignUniqueID(msg *dns.Msg) { // try a random ID 10000 times - for i := 0; i < 10000; i++ { // don't try forever + for range 10000 { // don't try forever _, exists := trc.inFlightQueries[msg.Id] if !exists { return // we are unique, yay! diff --git a/service/resolver/resolver_test.go b/service/resolver/resolver_test.go index 5292155b1..e26d4a319 100644 --- a/service/resolver/resolver_test.go +++ b/service/resolver/resolver_test.go @@ -62,7 +62,7 @@ func TestSingleResolving(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(100) - for i := 0; i < 100; i++ { + for range 100 { startQuery(t, wg, resolver.Conn, &Query{ FQDN: <-domainFeed, QType: dns.Type(dns.TypeA), @@ -94,7 +94,7 @@ func TestBulkResolving(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(100) - for i := 0; i < 100; i++ { + for range 100 { go startQuery(t, wg, resolver.Conn, &Query{ FQDN: <-domainFeed, QType: dns.Type(dns.TypeA), diff --git a/service/status/status.go b/service/status/status.go index 7a7ce8896..0e78a8eb3 100644 --- a/service/status/status.go +++ b/service/status/status.go @@ -84,11 +84,9 @@ func (s *Status) buildSystemStatus() *SystemStatusRecord { OnlineStatus: netenv.GetOnlineStatus(), Modules: make([]mgr.StateUpdate, 0, len(s.states)), } - for _, v := range s.states { + for _, newStateUpdate := range s.states { // Deep copy state. - newStateUpdate := v - newStateUpdate.States = make([]mgr.State, len(v.States)) - copy(newStateUpdate.States, v.States) + newStateUpdate.States = append([]mgr.State(nil), newStateUpdate.States...) status.Modules = append(status.Modules, newStateUpdate) // Check if state is worst so far. diff --git a/spn/access/client_test.go b/spn/access/client_test.go index 93c5e81e5..acc5bd2b8 100644 --- a/spn/access/client_test.go +++ b/spn/access/client_test.go @@ -63,7 +63,7 @@ func loginAndRefresh(t *testing.T, doLogin bool, refreshTimes int) { t.Logf("auth token: %+v", authToken.Token) } - for i := 0; i < refreshTimes; i++ { + for range refreshTimes { user, _, err := UpdateUser() if err != nil { t.Fatalf("getting profile failed: %s", err) diff --git a/spn/access/token/pblind.go b/spn/access/token/pblind.go index 1342a2831..ef97bb78f 100644 --- a/spn/access/token/pblind.go +++ b/spn/access/token/pblind.go @@ -217,7 +217,7 @@ func (pbh *PBlindHandler) CreateSetup() (state *PBlindSignerState, setupResponse } // Go through the batch. 
-	for i := 0; i < pbh.opts.BatchSize; i++ {
+	for i := range pbh.opts.BatchSize {
 		info, err := pbh.makeInfo(i + 1)
 		if err != nil {
 			return nil, nil, fmt.Errorf("failed to create info #%d: %w", i, err)
@@ -257,7 +257,7 @@ func (pbh *PBlindHandler) CreateTokenRequest(requestSetup *PBlindSetupResponse)
 	}
 	// Go through the batch.
-	for i := 0; i < pbh.opts.BatchSize; i++ {
+	for i := range pbh.opts.BatchSize {
 		// Check if we have setup data.
 		if requestSetup.Msgs[i] == nil {
 			return nil, fmt.Errorf("missing setup data #%d", i)
@@ -319,7 +319,7 @@ func (pbh *PBlindHandler) IssueTokens(state *PBlindSignerState, request *PBlindT
 	}
 	// Go through the batch.
-	for i := 0; i < pbh.opts.BatchSize; i++ {
+	for i := range pbh.opts.BatchSize {
 		// Check if we have request data.
 		if request.Msgs[i] == nil {
 			return nil, fmt.Errorf("missing request data #%d", i)
@@ -360,7 +360,7 @@ func (pbh *PBlindHandler) ProcessIssuedTokens(issuedTokens *IssuedPBlindTokens)
 	finalizedTokens := make([]*PBlindToken, pbh.opts.BatchSize)
 	// Go through the batch.
-	for i := 0; i < pbh.opts.BatchSize; i++ {
+	for i := range pbh.opts.BatchSize {
 		// Finalize token.
 		err := pbh.requestState[i].State.ProcessMessage3(*issuedTokens.Msgs[i])
 		if err != nil {
diff --git a/spn/access/token/pblind_test.go b/spn/access/token/pblind_test.go
index b25ac71be..b775895df 100644
--- a/spn/access/token/pblind_test.go
+++ b/spn/access/token/pblind_test.go
@@ -124,7 +124,7 @@ func TestPBlindLibrary(t *testing.T) {
 	// Create signers and prep requests.
 	start := time.Now()
-	for i := 0; i < batchSize; i++ {
+	for i := range batchSize {
 		signer, err := pblind.CreateSigner(sk, info)
 		if err != nil {
 			t.Fatal(err)
@@ -146,7 +146,7 @@ func TestPBlindLibrary(t *testing.T) {
 	// Create requesters and create requests.
 	start = time.Now()
-	for i := 0; i < batchSize; i++ {
+	for i := range batchSize {
 		requester, err := pblind.CreateRequester(pk, info, msgStr)
 		if err != nil {
 			t.Fatal(err)
@@ -178,7 +178,7 @@ func TestPBlindLibrary(t *testing.T) {
 	// Sign requests
 	start = time.Now()
-	for i := 0; i < batchSize; i++ {
+	for i := range batchSize {
 		var msg2S pblind.Message2
 		_, err = asn1.Unmarshal(toServer[i], &msg2S)
 		if err != nil {
@@ -204,7 +204,7 @@ func TestPBlindLibrary(t *testing.T) {
 	// Verify signed requests
 	start = time.Now()
-	for i := 0; i < batchSize; i++ {
+	for i := range batchSize {
 		var msg3R pblind.Message3
 		_, err := asn1.Unmarshal(toClient[i], &msg3R)
 		if err != nil {
@@ -234,7 +234,7 @@ func TestPBlindLibrary(t *testing.T) {
 	// Verify on server
 	start = time.Now()
-	for i := 0; i < batchSize; i++ {
+	for i := range batchSize {
 		var sig pblind.Signature
 		_, err := asn1.Unmarshal(toServer[i], &sig)
 		if err != nil {
diff --git a/spn/cabin/keys.go b/spn/cabin/keys.go
index 665c6d656..3faa36719 100644
--- a/spn/cabin/keys.go
+++ b/spn/cabin/keys.go
@@ -144,7 +144,7 @@ func (id *Identity) MaintainExchKeys(newStatus *hub.Status, now time.Time) (chan
 func (id *Identity) createExchKey(eks *providedExchKeyScheme, now time.Time) error {
 	// get ID
 	var keyID string
-	for i := 0; i < 1000000; i++ { // not forever
+	for range 1000000 { // not forever
 		// generate new ID
 		b, err := rng.Bytes(3)
 		if err != nil {
diff --git a/spn/cabin/keys_test.go b/spn/cabin/keys_test.go
index 4d135f012..98b5db285 100644
--- a/spn/cabin/keys_test.go
+++ b/spn/cabin/keys_test.go
@@ -19,7 +19,7 @@ func TestKeyMaintenance(t *testing.T) {
 	changeCnt := 0
 	now := time.Now()
-	for i := 0; i < iterations; i++ {
+	for range iterations {
 		changed, err := id.MaintainExchKeys(id.Hub.Status, now)
 		if err != nil {
 			t.Fatal(err)
diff --git a/spn/crew/op_connect_test.go b/spn/crew/op_connect_test.go
index 9a7e24f01..010bca0f0 100644
--- a/spn/crew/op_connect_test.go
+++ b/spn/crew/op_connect_test.go
@@ -61,8 +61,7 @@ func TestConnectOp(t *testing.T) {
 		t.Fatalf("failed to update identity: %s", err)
 	}
 	EnableConnecting(identity.Hub)
-
-	for i := 0; i < 1; i++ {
+	{
 		appConn, sluiceConn := net.Pipe()
 		_, tErr := NewConnectOp(&Tunnel{
 			connInfo: &network.Connection{
diff --git a/spn/docks/module.go b/spn/docks/module.go
index ceb97a5a3..2873879d4 100644
--- a/spn/docks/module.go
+++ b/spn/docks/module.go
@@ -50,7 +50,7 @@ func registerCrane(crane *Crane) error {
 	defer cranesLock.Unlock()
 	// Generate new IDs until a unique one is found.
-	for i := 0; i < 100; i++ {
+	for range 100 {
 		// Generate random ID.
 		randomID, err := rng.Bytes(3)
 		if err != nil {
diff --git a/spn/hub/hub.go b/spn/hub/hub.go
index 0caac39b7..6817ce2c0 100644
--- a/spn/hub/hub.go
+++ b/spn/hub/hub.go
@@ -425,7 +425,7 @@ func equalStringSlice(a, b []string) bool {
 		return false
 	}
-	for i := 0; i < len(a); i++ {
+	for i := range len(a) {
 		if a[i] != b[i] {
 			return false
 		}
diff --git a/spn/navigator/findnearest_test.go b/spn/navigator/findnearest_test.go
index 596d7779b..5f09ea35f 100644
--- a/spn/navigator/findnearest_test.go
+++ b/spn/navigator/findnearest_test.go
@@ -12,7 +12,7 @@ func TestFindNearest(t *testing.T) {
 	fakeLock.Lock()
 	defer fakeLock.Unlock()
-	for i := 0; i < 100; i++ {
+	for range 100 {
 		// Create a random destination address
 		ip4, loc4 := createGoodIP(true)
@@ -24,7 +24,7 @@ func TestFindNearest(t *testing.T) {
 		}
 	}
-	for i := 0; i < 100; i++ {
+	for range 100 {
 		// Create a random destination address
 		ip6, loc6 := createGoodIP(true)
diff --git a/spn/navigator/findroutes_test.go b/spn/navigator/findroutes_test.go
index ed7793c1e..2056c0a16 100644
--- a/spn/navigator/findroutes_test.go
+++ b/spn/navigator/findroutes_test.go
@@ -13,7 +13,7 @@ func TestFindRoutes(t *testing.T) {
 	fakeLock.Lock()
 	defer fakeLock.Unlock()
-	for i := 0; i < 1; i++ {
+	for i := range 1 {
 		// Create a random destination address
 		dstIP, _ := createGoodIP(i%2 == 0)
@@ -37,13 +37,13 @@ func BenchmarkFindRoutes(b *testing.B) {
 	// Pre-generate 100 IPs
 	preGenIPs := make([]net.IP, 0, 100)
-	for i := 0; i < cap(preGenIPs); i++ {
+	for i := range cap(preGenIPs) {
 		ip, _ := createGoodIP(i%2 == 0)
 		preGenIPs = append(preGenIPs, ip)
 	}
 	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
+	for i := range b.N {
 		routes, err := m.FindRoutes(preGenIPs[i%len(preGenIPs)], m.DefaultOptions())
 		if err != nil {
 			b.Error(err)
diff --git a/spn/navigator/map_test.go b/spn/navigator/map_test.go
index bea2d4779..99e911882 100644
--- a/spn/navigator/map_test.go
+++ b/spn/navigator/map_test.go
@@ -58,7 +58,7 @@ func createRandomTestMap(seed int64, size int) *Map {
 	}
 	// Create Hub list.
-	var hubs []*hub.Hub
+	hubs := make([]*hub.Hub, 0, size)
 	// Create Intel data structure.
 	mapIntel := &hub.Intel{
@@ -69,7 +69,7 @@ func createRandomTestMap(seed int64, size int) *Map {
 	var currentGroup string
 	// Create [size] fake Hubs.
-	for i := 0; i < size; i++ {
+	for i := range size {
 		// Change group every 5 Hubs.
 		if i%5 == 0 {
 			currentGroup = gofakeit.Username()
@@ -81,7 +81,7 @@ func createRandomTestMap(seed int64, size int) *Map {
 	}
 	// Fake three superseeded Hubs.
-	for i := 0; i < 3; i++ {
+	for i := range 3 {
 		h := hubs[size-1-i]
 		// Set FirstSeen in the past and copy an IP address of an existing Hub.
@@ -95,7 +95,7 @@ func createRandomTestMap(seed int64, size int) *Map {
 	// Create Lanes between Hubs in order to create the network.
 	totalConnections := size * 10
-	for i := 0; i < totalConnections; i++ {
+	for range totalConnections {
 		// Get new random indexes.
 		indexA := gofakeit.Number(0, size-1)
 		indexB := gofakeit.Number(0, size-1)
@@ -246,7 +246,7 @@ func createFakeHub(group string, randomFailes bool, mapIntel *hub.Intel) *hub.Hu
 func createGoodIP(v4 bool) (net.IP, *geoip.Location) {
 	var candidate net.IP
-	for i := 0; i < 100; i++ {
+	for range 100 {
 		if v4 {
 			candidate = net.ParseIP(gofakeit.IPv4Address())
 		} else {
diff --git a/spn/navigator/module.go b/spn/navigator/module.go
index 2568cb7e3..b2e4bba8c 100644
--- a/spn/navigator/module.go
+++ b/spn/navigator/module.go
@@ -90,7 +90,7 @@ func start() error {
 	// The "wait" parameter times out after 1 second.
 	// Allow 30 seconds for both databases to load.
 geoInitCheck:
-	for i := 0; i < 30; i++ {
+	for range 30 {
 		switch {
 		case !geoip.IsInitialized(false, true): // First, IPv4.
 		case !geoip.IsInitialized(true, true): // Then, IPv6.
diff --git a/spn/navigator/optimize.go b/spn/navigator/optimize.go
index 76f101c37..f396b344e 100644
--- a/spn/navigator/optimize.go
+++ b/spn/navigator/optimize.go
@@ -373,7 +373,7 @@ func (m *Map) optimizeForDistanceConstraint(result *OptimizationResult, max int)
 	// Add approach.
 	result.addApproach(fmt.Sprintf("Satisfy max hop constraint of %d globally.", optimizationHopDistanceTarget))
-	for i := 0; i < max; i++ {
+	for range max {
 		// Sort by lowest cost.
 		sort.Sort(sortBySuggestedHopDistanceAndLowestMeasuredCost(m.regardedPins))
diff --git a/spn/patrol/http.go b/spn/patrol/http.go
index c3bb14bdd..9a936fed9 100644
--- a/spn/patrol/http.go
+++ b/spn/patrol/http.go
@@ -84,7 +84,7 @@ func checkHTTPSConnectivity(ctx context.Context, network string, checks int, req
 	// Run tests.
 	var succeeded int
-	for i := 0; i < checks; i++ {
+	for range checks {
 		if checkHTTPSConnection(ctx, network) {
 			succeeded++
 		}
diff --git a/spn/ships/connection_test.go b/spn/ships/connection_test.go
index 5d03927bf..d41726459 100644
--- a/spn/ships/connection_test.go
+++ b/spn/ships/connection_test.go
@@ -35,8 +35,7 @@ func TestConnections(t *testing.T) {
 		registryLock.Unlock()
 	})
-	for k, v := range registry { //nolint:paralleltest // False positive.
-		protocol, builder := k, v
+	for protocol, builder := range registry {
 		t.Run(protocol, func(t *testing.T) {
 			t.Parallel()
@@ -85,7 +84,7 @@ func TestConnections(t *testing.T) {
 			assert.Equal(t, testData, buf, "should match")
 			fmt.Print(".")
-			for i := 0; i < 100; i++ {
+			for range 100 {
 				// server send
 				err = srvShip.Load(testData)
 				if err != nil {
diff --git a/spn/ships/http_shared.go b/spn/ships/http_shared.go
index bae861e57..b174dfc2c 100644
--- a/spn/ships/http_shared.go
+++ b/spn/ships/http_shared.go
@@ -121,10 +121,9 @@ func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) error {
 	sharedHTTPServers[port] = shared
 	// Start servers in service workers.
-	for _, listener := range listeners {
-		serviceListener := listener
+	for _, serviceListener := range listeners {
 		module.mgr.Go(
-			fmt.Sprintf("shared http server listener on %s", listener.Addr()),
+			fmt.Sprintf("shared http server listener on %s", serviceListener.Addr()),
 			func(_ *mgr.WorkerCtx) error {
 				err := shared.server.Serve(serviceListener)
 				if !errors.Is(http.ErrServerClosed, err) {
diff --git a/spn/ships/testship_test.go b/spn/ships/testship_test.go
index 7e026b92a..a0c117318 100644
--- a/spn/ships/testship_test.go
+++ b/spn/ships/testship_test.go
@@ -17,7 +17,7 @@ func TestTestShip(t *testing.T) {
 	srvShip := tShip.Reverse()
-	for i := 0; i < 100; i++ {
+	for range 100 {
 		// client send
 		err := ship.Load(testData)
 		if err != nil {
diff --git a/spn/terminal/session_test.go b/spn/terminal/session_test.go
index e61d1f526..9ddfec34f 100644
--- a/spn/terminal/session_test.go
+++ b/spn/terminal/session_test.go
@@ -16,7 +16,7 @@ func TestRateLimit(t *testing.T) {
 	s := NewSession()
 	// Everything should be okay within the min limit.
-	for i := 0; i < rateLimitMinOps; i++ {
+	for range rateLimitMinOps {
 		tErr = s.RateLimit()
 		if tErr != nil {
 			t.Error("should not rate limit within min limit")
@@ -24,7 +24,7 @@ func TestRateLimit(t *testing.T) {
 	}
 	// Somewhere here we should rate limiting.
-	for i := 0; i < rateLimitMaxOpsPerSecond; i++ {
+	for range rateLimitMaxOpsPerSecond {
 		tErr = s.RateLimit()
 	}
 	assert.ErrorIs(t, tErr, ErrRateLimited, "should rate limit")
@@ -37,7 +37,7 @@ func TestSuspicionLimit(t *testing.T) {
 	s := NewSession()
 	// Everything should be okay within the min limit.
-	for i := 0; i < rateLimitMinSuspicion; i++ {
+	for range rateLimitMinSuspicion {
 		tErr = s.RateLimit()
 		if tErr != nil {
 			t.Error("should not rate limit within min limit")
 		}
@@ -46,7 +46,7 @@ func TestSuspicionLimit(t *testing.T) {
 	}
 	// Somewhere here we should rate limiting.
-	for i := 0; i < rateLimitMaxSuspicionPerSecond; i++ {
+	for range rateLimitMaxSuspicionPerSecond {
 		s.ReportSuspiciousActivity(SusFactorCommon)
 		tErr = s.RateLimit()
 	}
@@ -66,8 +66,7 @@ func TestConcurrencyLimit(t *testing.T) {
 	// Start many workers to test concurrency.
 	wg.Add(workers)
-	for i := 0; i < workers; i++ {
-		workerNum := i
+	for workerNum := range workers {
 		go func() {
 			defer func() {
 				_ = recover()
diff --git a/spn/unit/unit_test.go b/spn/unit/unit_test.go
index 636f1231f..010e9df75 100644
--- a/spn/unit/unit_test.go
+++ b/spn/unit/unit_test.go
@@ -39,9 +39,9 @@ func TestUnit(t *testing.T) { //nolint:paralleltest
 	var wg sync.WaitGroup
 	wg.Add(workers)
 	sizePerWorker := size / workers
-	for i := 0; i < workers; i++ {
+	for range workers {
 		go func() {
-			for i := 0; i < sizePerWorker; i++ {
+			for range sizePerWorker {
 				u := s.NewUnit()
 				// Make 1% high priority.