internal/postgres: use moduleID in getModuleLicenses
getModuleLicenses now accepts moduleID instead of module path and
version as args.
In a later CL, we will be dropping licenses.module_path and
licenses.module_id.
For golang/go#39629
Change-Id: I18c525c584fd181372ef6c01241a90c19cb28a96
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/271747
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/postgres/insert_module.go b/internal/postgres/insert_module.go
index 9e81744..3906554 100644
--- a/internal/postgres/insert_module.go
+++ b/internal/postgres/insert_module.go
@@ -23,6 +23,7 @@
"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/derrors"
+ "golang.org/x/pkgsite/internal/licenses"
"golang.org/x/pkgsite/internal/log"
"golang.org/x/pkgsite/internal/stdlib"
"golang.org/x/pkgsite/internal/version"
@@ -51,9 +52,6 @@
// inserted. Rows that currently exist should not be missing from the
// new module. We want to be sure that we will overwrite every row that
// pertains to the module.
- if err := db.compareLicenses(ctx, m); err != nil {
- return err
- }
if err := db.comparePaths(ctx, m); err != nil {
return err
}
@@ -81,6 +79,13 @@
if err != nil {
return err
}
+ // Compare existing data from the database, and the module to be
+ // inserted. Rows that currently exist should not be missing from the
+ // new module. We want to be sure that we will overwrite every row that
+ // pertains to the module.
+ if err := db.compareLicenses(ctx, moduleID, m.Licenses); err != nil {
+ return err
+ }
if err := insertLicenses(ctx, tx, m, moduleID); err != nil {
return err
}
@@ -561,15 +566,15 @@
// compareLicenses compares m.Licenses with the existing licenses for
// m.ModulePath and m.Version in the database. It returns an error if there
// are licenses in the licenses table that are not present in m.Licenses.
-func (db *DB) compareLicenses(ctx context.Context, m *internal.Module) (err error) {
- defer derrors.Wrap(&err, "compareLicenses(ctx, %q, %q)", m.ModulePath, m.Version)
- dbLicenses, err := db.getModuleLicenses(ctx, m.ModulePath, m.Version)
+func (db *DB) compareLicenses(ctx context.Context, moduleID int, lics []*licenses.License) (err error) {
+ defer derrors.Wrap(&err, "compareLicenses(ctx, %d)", moduleID)
+ dbLicenses, err := db.getModuleLicenses(ctx, moduleID)
if err != nil {
return err
}
set := map[string]bool{}
- for _, l := range m.Licenses {
+ for _, l := range lics {
set[l.FilePath] = true
}
for _, l := range dbLicenses {
diff --git a/internal/postgres/licenses.go b/internal/postgres/licenses.go
index b5aecdb..87755a9 100644
--- a/internal/postgres/licenses.go
+++ b/internal/postgres/licenses.go
@@ -80,21 +80,18 @@
// getModuleLicenses returns all licenses associated with the given module path and
// version. These are the top-level licenses in the module zip file.
// It returns an InvalidArgument error if the module path or version is invalid.
-func (db *DB) getModuleLicenses(ctx context.Context, modulePath, resolvedVersion string) (_ []*licenses.License, err error) {
- defer derrors.Wrap(&err, "getModuleLicenses(ctx, %q, %q)", modulePath, resolvedVersion)
+func (db *DB) getModuleLicenses(ctx context.Context, moduleID int) (_ []*licenses.License, err error) {
+ defer derrors.Wrap(&err, "getModuleLicenses(ctx, %d)", moduleID)
- if modulePath == "" || resolvedVersion == "" {
- return nil, fmt.Errorf("neither modulePath nor version can be empty: %w", derrors.InvalidArgument)
- }
query := `
SELECT
types, file_path, contents, coverage
FROM
licenses
WHERE
- module_path = $1 AND version = $2 AND position('/' in file_path) = 0
+ module_id = $1 AND position('/' in file_path) = 0
`
- rows, err := db.db.Query(ctx, query, modulePath, resolvedVersion)
+ rows, err := db.db.Query(ctx, query, moduleID)
if err != nil {
return nil, err
}
diff --git a/internal/postgres/licenses_test.go b/internal/postgres/licenses_test.go
index d18e6c9..6e52de0 100644
--- a/internal/postgres/licenses_test.go
+++ b/internal/postgres/licenses_test.go
@@ -144,7 +144,17 @@
t.Fatal(err)
}
- got, err := testDB.getModuleLicenses(ctx, modulePath, testModule.Version)
+ var moduleID int
+ query := `
+ SELECT m.id
+ FROM modules m
+ WHERE
+ m.module_path = $1
+ AND m.version = $2;`
+ if err := testDB.db.QueryRow(ctx, query, modulePath, testModule.Version).Scan(&moduleID); err != nil {
+ t.Fatal(err)
+ }
+ got, err := testDB.getModuleLicenses(ctx, moduleID)
if err != nil {
t.Fatal(err)
}