internal/postgres: use moduleID in getModuleLicenses

getModuleLicenses now accepts moduleID instead of module path and
version as args.

In a later CL, we will be dropping licenses.module_path and
licenses.module_id.

For golang/go#39629

Change-Id: I18c525c584fd181372ef6c01241a90c19cb28a96
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/271747
Trust: Julie Qiu <julie@golang.org>
Run-TryBot: Julie Qiu <julie@golang.org>
TryBot-Result: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
diff --git a/internal/postgres/insert_module.go b/internal/postgres/insert_module.go
index 9e81744..3906554 100644
--- a/internal/postgres/insert_module.go
+++ b/internal/postgres/insert_module.go
@@ -23,6 +23,7 @@
 	"golang.org/x/pkgsite/internal"
 	"golang.org/x/pkgsite/internal/database"
 	"golang.org/x/pkgsite/internal/derrors"
+	"golang.org/x/pkgsite/internal/licenses"
 	"golang.org/x/pkgsite/internal/log"
 	"golang.org/x/pkgsite/internal/stdlib"
 	"golang.org/x/pkgsite/internal/version"
@@ -51,9 +52,6 @@
 	// inserted. Rows that currently exist should not be missing from the
 	// new module. We want to be sure that we will overwrite every row that
 	// pertains to the module.
-	if err := db.compareLicenses(ctx, m); err != nil {
-		return err
-	}
 	if err := db.comparePaths(ctx, m); err != nil {
 		return err
 	}
@@ -81,6 +79,13 @@
 		if err != nil {
 			return err
 		}
+		// Compare existing data from the database, and the module to be
+		// inserted. Rows that currently exist should not be missing from the
+		// new module. We want to be sure that we will overwrite every row that
+		// pertains to the module.
+		if err := db.compareLicenses(ctx, moduleID, m.Licenses); err != nil {
+			return err
+		}
 		if err := insertLicenses(ctx, tx, m, moduleID); err != nil {
 			return err
 		}
@@ -561,15 +566,15 @@
 // compareLicenses compares m.Licenses with the existing licenses for
 // m.ModulePath and m.Version in the database. It returns an error if there
 // are licenses in the licenses table that are not present in m.Licenses.
-func (db *DB) compareLicenses(ctx context.Context, m *internal.Module) (err error) {
-	defer derrors.Wrap(&err, "compareLicenses(ctx, %q, %q)", m.ModulePath, m.Version)
-	dbLicenses, err := db.getModuleLicenses(ctx, m.ModulePath, m.Version)
+func (db *DB) compareLicenses(ctx context.Context, moduleID int, lics []*licenses.License) (err error) {
+	defer derrors.Wrap(&err, "compareLicenses(ctx, %d)", moduleID)
+	dbLicenses, err := db.getModuleLicenses(ctx, moduleID)
 	if err != nil {
 		return err
 	}
 
 	set := map[string]bool{}
-	for _, l := range m.Licenses {
+	for _, l := range lics {
 		set[l.FilePath] = true
 	}
 	for _, l := range dbLicenses {
diff --git a/internal/postgres/licenses.go b/internal/postgres/licenses.go
index b5aecdb..87755a9 100644
--- a/internal/postgres/licenses.go
+++ b/internal/postgres/licenses.go
@@ -80,21 +80,18 @@
 // getModuleLicenses returns all licenses associated with the given module path and
 // version. These are the top-level licenses in the module zip file.
 // It returns an InvalidArgument error if the module path or version is invalid.
-func (db *DB) getModuleLicenses(ctx context.Context, modulePath, resolvedVersion string) (_ []*licenses.License, err error) {
-	defer derrors.Wrap(&err, "getModuleLicenses(ctx, %q, %q)", modulePath, resolvedVersion)
+func (db *DB) getModuleLicenses(ctx context.Context, moduleID int) (_ []*licenses.License, err error) {
+	defer derrors.Wrap(&err, "getModuleLicenses(ctx, %d)", moduleID)
 
-	if modulePath == "" || resolvedVersion == "" {
-		return nil, fmt.Errorf("neither modulePath nor version can be empty: %w", derrors.InvalidArgument)
-	}
 	query := `
 	SELECT
 		types, file_path, contents, coverage
 	FROM
 		licenses
 	WHERE
-		module_path = $1 AND version = $2 AND position('/' in file_path) = 0
+		module_id = $1 AND position('/' in file_path) = 0
     `
-	rows, err := db.db.Query(ctx, query, modulePath, resolvedVersion)
+	rows, err := db.db.Query(ctx, query, moduleID)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/postgres/licenses_test.go b/internal/postgres/licenses_test.go
index d18e6c9..6e52de0 100644
--- a/internal/postgres/licenses_test.go
+++ b/internal/postgres/licenses_test.go
@@ -144,7 +144,17 @@
 		t.Fatal(err)
 	}
 
-	got, err := testDB.getModuleLicenses(ctx, modulePath, testModule.Version)
+	var moduleID int
+	query := `
+		SELECT m.id
+		FROM modules m
+		WHERE
+		    m.module_path = $1
+		    AND m.version = $2;`
+	if err := testDB.db.QueryRow(ctx, query, modulePath, testModule.Version).Scan(&moduleID); err != nil {
+		t.Fatal(err)
+	}
+	got, err := testDB.getModuleLicenses(ctx, moduleID)
 	if err != nil {
 		t.Fatal(err)
 	}